mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-26 19:33:41 +08:00 
			
		
		
		
	Compare commits
	
		
			157 Commits
		
	
	
		
			v0.4.10-al
			...
			v0.5.6-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 47c08c72ce | ||
|  | 53b2cace0b | ||
|  | f0fc991b44 | ||
|  | 594f06e7b0 | ||
|  | 197d1d7a9d | ||
|  | f9b748c2ca | ||
|  | fd98463611 | ||
|  | f5a1cd3463 | ||
|  | 8651451e53 | ||
|  | 1c5bb97a42 | ||
|  | de868e4e4e | ||
|  | 1d258cc898 | ||
|  | 37e09d764c | ||
|  | 159b9e3369 | ||
|  | 92001986db | ||
|  | a5647b1ea7 | ||
|  | 215e54fc96 | ||
|  | ecf8a6d875 | ||
|  | 24df3e5f62 | ||
|  | 12ef9679a7 | ||
|  | 328aa68255 | ||
|  | 4335f005a6 | ||
|  | fe26a1448d | ||
|  | 42451d9d02 | ||
|  | 25c4c111ab | ||
|  | 0d50ad4b2b | ||
|  | 959bcdef88 | ||
|  | 39ae8075e4 | ||
|  | b57a0eca16 | ||
|  | 1b4cc78890 | ||
|  | 420c375140 | ||
|  | 01863d3e44 | ||
|  | d0a0e871e1 | ||
|  | bd6fe1e93c | ||
|  | c55bb67818 | ||
|  | 0f949c3782 | ||
|  | a721a5b6f9 | ||
|  | 276163affd | ||
|  | 621eb91b46 | ||
|  | 7e575abb95 | ||
|  | 9db93316c4 | ||
|  | c3dc315e75 | ||
|  | 04acdb1ccb | ||
|  | f0d5e102a3 | ||
|  | abbf2fded0 | ||
|  | ef2c5abb5b | ||
|  | 56b5007379 | ||
|  | d09d317459 | ||
|  | 1c4409ae80 | ||
|  | 5ee24e8acf | ||
|  | 4f2f911e4d | ||
|  | fdb2cccf65 | ||
|  | a3e267df7e | ||
|  | ac7c0f3a76 | ||
|  | efeb9a16ce | ||
|  | 05e4f2b439 | ||
|  | 7e058bfb9b | ||
|  | dfaa0183b7 | ||
|  | 1b56becfaa | ||
|  | 23b1c63538 | ||
|  | 49d1a63402 | ||
|  | 2a7b82650c | ||
|  | 8ea7b9aae2 | ||
|  | 5136b12612 | ||
|  | 80a49e01a3 | ||
|  | 8fb082ba3b | ||
|  | 86c2627c24 | ||
|  | 90b4cac7f3 | ||
|  | e4bacc45d6 | ||
|  | da1d81998f | ||
|  | cac61b9f66 | ||
|  | 3da12e99d9 | ||
|  | 4ef5e2020c | ||
|  | af20063a8d | ||
|  | ca512f6a38 | ||
|  | 0e9ff8825e | ||
|  | e0b4f96b5b | ||
|  | eae9b6e607 | ||
|  | 7bddc73b96 | ||
|  | 2a527ee436 | ||
|  | e42119b73d | ||
|  | 821c559e89 | ||
|  | 7e2bca7e9c | ||
|  | 1e16ef3e0d | ||
|  | 476a46ad7e | ||
|  | c58f710227 | ||
|  | 150d068e9f | ||
|  | be780462f1 | ||
|  | f2159e1033 | ||
|  | 466005de07 | ||
|  | 2b088a1678 | ||
|  | 3a18cebe34 | ||
|  | cc36bf9c13 | ||
|  | 3b36608bbd | ||
|  | 29fa94e7d2 | ||
|  | 9c436921d1 | ||
|  | 463b0b3c51 | ||
|  | c3d85a28d4 | ||
|  | 7422b0d051 | ||
|  | 5a62357c93 | ||
|  | b464e2907a | ||
|  | d96cf2e84d | ||
|  | 446337c329 | ||
|  | 1dfa190e79 | ||
|  | 2d49ca6a07 | ||
|  | 89bcaaf989 | ||
|  | afcd1bd27b | ||
|  | c2c455c980 | ||
|  | 30a7f1a1c7 | ||
|  | c9d2e42a9e | ||
|  | 3fca6ff534 | ||
|  | 8cbbeb784f | ||
|  | ec88c0c240 | ||
|  | 065147b440 | ||
|  | fe8f216dd9 | ||
|  | b7d0616ae0 | ||
|  | ce9c8024a6 | ||
|  | 8a866078b2 | ||
|  | 3e81d8af45 | ||
|  | b8cb86c2c1 | ||
|  | f45d586400 | ||
|  | 50dec03ff3 | ||
|  | f31d400b6f | ||
|  | 130e6bfd83 | ||
|  | d1335ebc01 | ||
|  | e92da7928b | ||
|  | d1b6f492b6 | ||
|  | b9f6461dd4 | ||
|  | 0a39521a3d | ||
|  | c134604cee | ||
|  | 929e43ef81 | ||
|  | dce8bbe1ca | ||
|  | bc2f48b1f2 | ||
|  | 889af8b2db | ||
|  | 4eea096654 | ||
|  | 4ab3211c0e | ||
|  | 3da119efba | ||
|  | dccd66b852 | ||
|  | 2fcd6852e0 | ||
|  | 9b4d1964d4 | ||
|  | 806bf8241c | ||
|  | ce93c9b6b2 | ||
|  | 4ec4289565 | ||
|  | 3dc5a0f91d | ||
|  | 80a846673a | ||
|  | 26c6719ea3 | ||
|  | c87e05bfc2 | ||
|  | e6938bd236 | ||
|  | 8f721d67a5 | ||
|  | fcc1e2d568 | ||
|  | 9a1db61675 | ||
|  | 3c940113ab | ||
|  | 0495b9a0d7 | ||
|  | 12a0e7105e | ||
|  | e628b643cd | ||
|  | 675847bf98 | ||
|  | 2ff15baf66 | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -5,3 +5,4 @@ upload | |||||||
| *.db | *.db | ||||||
| build | build | ||||||
| *.db-journal | *.db-journal | ||||||
|  | logs | ||||||
| @@ -1,10 +1,11 @@ | |||||||
| FROM node:16 as builder | FROM node:16 as builder | ||||||
|  |  | ||||||
| WORKDIR /build | WORKDIR /build | ||||||
|  | COPY web/package.json . | ||||||
|  | RUN npm install | ||||||
| COPY ./web . | COPY ./web . | ||||||
| COPY ./VERSION . | COPY ./VERSION . | ||||||
| RUN npm install | RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||||
| RUN REACT_APP_VERSION=$(cat VERSION) npm run build |  | ||||||
|  |  | ||||||
| FROM golang AS builder2 | FROM golang AS builder2 | ||||||
|  |  | ||||||
| @@ -13,9 +14,10 @@ ENV GO111MODULE=on \ | |||||||
|     GOOS=linux |     GOOS=linux | ||||||
|  |  | ||||||
| WORKDIR /build | WORKDIR /build | ||||||
|  | ADD go.mod go.sum ./ | ||||||
|  | RUN go mod download | ||||||
| COPY . . | COPY . . | ||||||
| COPY --from=builder /build/build ./web/build | COPY --from=builder /build/build ./web/build | ||||||
| RUN go mod download |  | ||||||
| RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
| FROM alpine | FROM alpine | ||||||
|   | |||||||
							
								
								
									
										31
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -1,5 +1,5 @@ | |||||||
| <p align="right"> | <p align="right"> | ||||||
|     <a href="./README.md">中文</a> | <strong>English</strong> |     <a href="./README.md">中文</a> | <strong>English</strong> | <a href="./README.ja.md">日本語</a> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
| <p align="center"> | <p align="center"> | ||||||
| @@ -10,7 +10,7 @@ | |||||||
|  |  | ||||||
| # One API | # One API | ||||||
|  |  | ||||||
| _✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_ | _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
|  |  | ||||||
| @@ -57,15 +57,13 @@ _✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_ | |||||||
| > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | ||||||
|  |  | ||||||
| ## Features | ## Features | ||||||
| 1. Supports multiple API access channels: | 1. Support for multiple large models: | ||||||
|     + [x] Official OpenAI channel (support proxy configuration) |    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|     + [x] **Azure OpenAI API** |    + [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||||
|     + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) |    + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) | ||||||
|     + [x] [OpenAI-SB](https://openai-sb.com) |    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|     + [x] [API2D](https://api2d.com/r/197971) |    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|     + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) |    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||||
|     + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`) |  | ||||||
|     + [x] Custom channel: Various third-party proxy services not included in the list |  | ||||||
| 2. Supports access to multiple channels through **load balancing**. | 2. Supports access to multiple channels through **load balancing**. | ||||||
| 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | ||||||
| 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | ||||||
| @@ -175,7 +173,12 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | |||||||
| <summary><strong>Deploy on Sealos</strong></summary> | <summary><strong>Deploy on Sealos</strong></summary> | ||||||
| <div> | <div> | ||||||
|  |  | ||||||
| Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md). | > Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users. | ||||||
|  |  | ||||||
|  | > Click the button below to deploy with one click.👇 | ||||||
|  |  | ||||||
|  | [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||||
|  |  | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
| </details> | </details> | ||||||
| @@ -187,7 +190,7 @@ Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/d | |||||||
| > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. | > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. | ||||||
|  |  | ||||||
| 1. First, fork the code. | 1. First, fork the code. | ||||||
| 2. Go to [Zeabur](https://zeabur.com/), log in, and enter the console. | 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. | ||||||
| 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | ||||||
| 4. Copy the connection parameters and run ```create database `one-api` ``` to create the database. | 4. Copy the connection parameters and run ```create database `one-api` ``` to create the database. | ||||||
| 5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. | 5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. | ||||||
| @@ -280,7 +283,7 @@ If the channel ID is not provided, load balancing will be used to distribute the | |||||||
|     + Double-check that your interface address and API Key are correct. |     + Double-check that your interface address and API Key are correct. | ||||||
|  |  | ||||||
| ## Related Projects | ## Related Projects | ||||||
| [FastGPT](https://github.com/c121914yu/FastGPT): Build an AI knowledge base in three minutes | [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||||
|  |  | ||||||
| ## Note | ## Note | ||||||
| This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | ||||||
|   | |||||||
							
								
								
									
										298
									
								
								README.ja.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										298
									
								
								README.ja.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,298 @@ | |||||||
|  | <p align="right"> | ||||||
|  |     <a href="./README.md">中文</a> | <a href="./README.en.md">English</a> | <strong>日本語</strong> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | <p align="center"> | ||||||
|  |   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | <div align="center"> | ||||||
|  |  | ||||||
|  | # One API | ||||||
|  |  | ||||||
|  | _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_ | ||||||
|  |  | ||||||
|  | </div> | ||||||
|  |  | ||||||
|  | <p align="center"> | ||||||
|  |   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> | ||||||
|  |     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||||
|  |     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> | ||||||
|  |     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||||
|  |     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> | ||||||
|  |     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> | ||||||
|  |   </a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | <p align="center"> | ||||||
|  |   <a href="#deployment">デプロイチュートリアル</a> | ||||||
|  |   · | ||||||
|  |   <a href="#usage">使用方法</a> | ||||||
|  |   · | ||||||
|  |   <a href="https://github.com/songquanpeng/one-api/issues">フィードバック</a> | ||||||
|  |   · | ||||||
|  |   <a href="#screenshots">スクリーンショット</a> | ||||||
|  |   · | ||||||
|  |   <a href="https://openai.justsong.cn/">ライブデモ</a> | ||||||
|  |   · | ||||||
|  |   <a href="#faq">FAQ</a> | ||||||
|  |   · | ||||||
|  |   <a href="#related-projects">関連プロジェクト</a> | ||||||
|  |   · | ||||||
|  |   <a href="https://iamazing.cn/page/reward">寄付</a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | > **警告**: この README は ChatGPT によって翻訳されています。翻訳ミスを発見した場合は遠慮なく PR を投稿してください。 | ||||||
|  |  | ||||||
|  | > **警告**: 英語版の Docker イメージは `justsong/one-api-en` です。 | ||||||
|  |  | ||||||
|  | > **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。 | ||||||
|  |  | ||||||
|  | ## 特徴 | ||||||
|  | 1. 複数の大型モデルをサポート: | ||||||
|  |    + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) | ||||||
|  |    + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) | ||||||
|  |    + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) | ||||||
|  |    + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|  |    + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|  |    + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) | ||||||
|  | 2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。 | ||||||
|  | 3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。 | ||||||
|  | 4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。 | ||||||
|  | 5. トークンの有効期限や使用回数を設定できる**トークン管理**に対応しています。 | ||||||
|  | 6. **バウチャー管理**に対応しており、バウチャーの一括生成やエクスポートが可能です。バウチャーは口座残高の補充に利用できます。 | ||||||
|  | 7. **チャンネル管理**に対応し、チャンネルの一括作成が可能。 | ||||||
|  | 8. グループごとに異なるレートを設定するための**ユーザーグループ**と**チャンネルグループ**をサポートしています。 | ||||||
|  | 9. チャンネル**モデルリスト設定**に対応。 | ||||||
|  | 10. **クォータ詳細チェック**をサポート。 | ||||||
|  | 11. **ユーザー招待報酬**をサポートします。 | ||||||
|  | 12. 米ドルでの残高表示が可能。 | ||||||
|  | 13. 新規ユーザー向けのお知らせ公開、リチャージリンク設定、初期残高設定に対応。 | ||||||
|  | 14. 豊富な**カスタマイズ**オプションを提供します: | ||||||
|  |     1. システム名、ロゴ、フッターのカスタマイズが可能。 | ||||||
|  |     2. HTML と Markdown コードを使用したホームページとアバウトページのカスタマイズ、または iframe を介したスタンドアロンウェブページの埋め込みをサポートしています。 | ||||||
|  | 15. システム・アクセストークンによる管理 API アクセスをサポートする。 | ||||||
|  | 16. Cloudflare Turnstile によるユーザー認証に対応。 | ||||||
|  | 17. ユーザー管理と複数のユーザーログイン/登録方法をサポート: | ||||||
|  |     + 電子メールによるログイン/登録とパスワードリセット。 | ||||||
|  |     + [GitHub OAuth](https://github.com/settings/applications/new)。 | ||||||
|  |     + WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。 | ||||||
|  | 18. 他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。 | ||||||
|  |  | ||||||
|  | ## デプロイメント | ||||||
|  | ### Docker デプロイメント | ||||||
|  | デプロイコマンド: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`。 | ||||||
|  |  | ||||||
|  | コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`。 | ||||||
|  |  | ||||||
|  | `-p 3000:3000` の最初の `3000` はホストのポートで、必要に応じて変更できます。 | ||||||
|  |  | ||||||
|  | データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。 | ||||||
|  |  | ||||||
|  | Nginxリファレンス設定: | ||||||
|  | ``` | ||||||
|  | server{ | ||||||
|  |    server_name openai.justsong.cn;  # ドメイン名は適宜変更 | ||||||
|  |  | ||||||
|  |    location / { | ||||||
|  |           client_max_body_size  64m; | ||||||
|  |           proxy_http_version 1.1; | ||||||
|  |           proxy_pass http://localhost:3000;  # それに応じてポートを変更 | ||||||
|  |           proxy_set_header Host $host; | ||||||
|  |           proxy_set_header X-Forwarded-For $remote_addr; | ||||||
|  |           proxy_cache_bypass $http_upgrade; | ||||||
|  |           proxy_set_header Accept-Encoding gzip; | ||||||
|  |           proxy_read_timeout 300s;  # GPT-4 はより長いタイムアウトが必要 | ||||||
|  |    } | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | 次に、Let's Encrypt certbot を使って HTTPS を設定します: | ||||||
|  | ```bash | ||||||
|  | # Ubuntu に certbot をインストール: | ||||||
|  | sudo snap install --classic certbot | ||||||
|  | sudo ln -s /snap/bin/certbot /usr/bin/certbot | ||||||
|  | # 証明書の生成と Nginx 設定の変更 | ||||||
|  | sudo certbot --nginx | ||||||
|  | # プロンプトに従う | ||||||
|  | # Nginx を再起動 | ||||||
|  | sudo service nginx restart | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | 初期アカウントのユーザー名は `root` で、パスワードは `123456` です。 | ||||||
|  |  | ||||||
|  | ### マニュアルデプロイ | ||||||
|  | 1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする: | ||||||
|  |    ```shell | ||||||
|  |    git clone https://github.com/songquanpeng/one-api.git | ||||||
|  |  | ||||||
|  |    # フロントエンドのビルド | ||||||
|  |    cd one-api/web | ||||||
|  |    npm install | ||||||
|  |    npm run build | ||||||
|  |  | ||||||
|  |    # バックエンドのビルド | ||||||
|  |    cd .. | ||||||
|  |    go mod download | ||||||
|  |    go build -ldflags "-s -w" -o one-api | ||||||
|  |    ``` | ||||||
|  | 2. 実行: | ||||||
|  |    ```shell | ||||||
|  |    chmod u+x one-api | ||||||
|  |    ./one-api --port 3000 --log-dir ./logs | ||||||
|  |    ``` | ||||||
|  | 3. [http://localhost:3000/](http://localhost:3000/) にアクセスし、ログインする。初期アカウントのユーザー名は `root`、パスワードは `123456` である。 | ||||||
|  |  | ||||||
|  | より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。 | ||||||
|  |  | ||||||
|  | ### マルチマシンデプロイ | ||||||
|  | 1. すべてのサーバに同じ `SESSION_SECRET` を設定する。 | ||||||
|  | 2. `SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。 | ||||||
|  | 3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。 | ||||||
|  | 4. データベースから定期的に設定を同期するサーバーには `SYNC_FREQUENCY` を設定する。 | ||||||
|  | 5. マスター以外のノードでは、オプションで `FRONTEND_BASE_URL` を設定して、ページ要求をマスターサーバーにリダイレクトすることができます。 | ||||||
|  | 6. マスター以外のノードには Redis を個別にインストールし、`REDIS_CONN_STRING` を設定して、キャッシュの有効期限が切れていないときにデータベースにゼロレイテンシーでアクセスできるようにする。 | ||||||
|  | 7. メインサーバーでもデータベースへのアクセスが高レイテンシになる場合は、Redis を有効にし、`SYNC_FREQUENCY` を設定してデータベースから定期的に設定を同期する必要がある。 | ||||||
|  |  | ||||||
|  | Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | ||||||
|  |  | ||||||
|  | ### コントロールパネル(例: Baota)への展開 | ||||||
|  | 詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。 | ||||||
|  |  | ||||||
|  | 配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。 | ||||||
|  |  | ||||||
|  | ### サードパーティプラットフォームへのデプロイ | ||||||
|  | <details> | ||||||
|  | <summary><strong>Sealos へのデプロイ</strong></summary> | ||||||
|  | <div> | ||||||
|  |  | ||||||
|  | > Sealos は、高い同時実行性、ダイナミックなスケーリング、数百万人のユーザーに対する安定した運用をサポートしています。 | ||||||
|  |  | ||||||
|  | > 下のボタンをクリックすると、ワンクリックで展開できます。👇 | ||||||
|  |  | ||||||
|  | [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | </div> | ||||||
|  | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><strong>Zeabur へのデプロイ</strong></summary> | ||||||
|  | <div> | ||||||
|  |  | ||||||
|  | > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 | ||||||
|  |  | ||||||
|  | 1. まず、コードをフォークする。 | ||||||
|  | 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 | ||||||
|  | 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 | ||||||
|  | 4. 接続パラメータをコピーし、```create database `one-api` ``` を実行してデータベースを作成する。 | ||||||
|  | 5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。 | ||||||
|  | 6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `<username>:<password>@tcp(<addr>:<port>)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。 | ||||||
|  | 7. 再デプロイを選択します。 | ||||||
|  | 8. Domains タブで、"my-one-api" のような適切なドメイン名の接頭辞を選択する。最終的なドメイン名は "my-one-api.zeabur.app" となります。独自のドメイン名を CNAME することもできます。 | ||||||
|  | 9. デプロイが完了するのを待ち、生成されたドメイン名をクリックして One API にアクセスします。 | ||||||
|  |  | ||||||
|  | </div> | ||||||
|  | </details> | ||||||
|  |  | ||||||
|  | ## コンフィグ | ||||||
|  | システムは箱から出してすぐに使えます。 | ||||||
|  |  | ||||||
|  | 環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。 | ||||||
|  |  | ||||||
|  | システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。 | ||||||
|  |  | ||||||
|  | ## 使用方法 | ||||||
|  | `Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。 | ||||||
|  |  | ||||||
|  | アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。 | ||||||
|  |  | ||||||
|  | OpenAI API が使用されている場所では、API Base に One API のデプロイアドレスを設定することを忘れないでください(例: `https://openai.justsong.cn`)。API Key は One API で生成されたトークンでなければなりません。 | ||||||
|  |  | ||||||
|  | 具体的な API Base のフォーマットは、使用しているクライアントに依存することに注意してください。 | ||||||
|  |  | ||||||
|  | ```mermaid | ||||||
|  | graph LR | ||||||
|  |     A(ユーザ) | ||||||
|  |     A --->|リクエスト| B(One API) | ||||||
|  |     B -->|中継リクエスト| C(OpenAI) | ||||||
|  |     B -->|中継リクエスト| D(Azure) | ||||||
|  |     B -->|中継リクエスト| E(その他のダウンストリームチャンネル) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | 現在のリクエストにどのチャネルを使うかを指定するには、トークンの後に チャネル ID を追加します: 例えば、`Authorization: Bearer ONE_API_KEY-CHANNEL_ID` のようにします。 | ||||||
|  | チャンネル ID を指定するためには、トークンは管理者によって作成される必要があることに注意してください。 | ||||||
|  |  | ||||||
|  | もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。 | ||||||
|  |  | ||||||
|  | ### 環境変数 | ||||||
|  | 1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。 | ||||||
|  |     + 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
|  | 2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。 | ||||||
|  |     + 例: `SESSION_SECRET=random_string` | ||||||
|  | 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | ||||||
|  |     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
|  | 4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | ||||||
|  |     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
|  | 5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | ||||||
|  |     + 例: `SYNC_FREQUENCY=60` | ||||||
|  | 6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | ||||||
|  |     + 例: `NODE_TYPE=slave` | ||||||
|  | 7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | ||||||
|  |     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
|  | 8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | ||||||
|  |     + 例: `CHANNEL_TEST_FREQUENCY=1440` | ||||||
|  | 9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | ||||||
|  |     + 例: `POLLING_INTERVAL=5` | ||||||
|  |  | ||||||
|  | ### コマンドラインパラメータ | ||||||
|  | 1. `--port <port_number>`: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。 | ||||||
|  |     + 例: `--port 3000` | ||||||
|  | 2. `--log-dir <log_dir>`: ログディレクトリを指定。設定しない場合、ログは保存されません。 | ||||||
|  |     + 例: `--log-dir ./logs` | ||||||
|  | 3. `--version`: システムのバージョン番号を表示して終了する。 | ||||||
|  | 4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。 | ||||||
|  |  | ||||||
|  | ## スクリーンショット | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## FAQ | ||||||
|  | 1. ノルマとは何か?どのように計算されますか?One API にはノルマ計算の問題はありますか? | ||||||
|  |     + ノルマ = グループ倍率 * モデル倍率 * (プロンプトトークンの数 + 完了トークンの数 * 完了倍率) | ||||||
|  |     + 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。 | ||||||
|  |     + ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。 | ||||||
|  | 2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか? | ||||||
|  |     + トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。 | ||||||
|  |     + トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。 | ||||||
|  | 3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか? | ||||||
|  |     + ユーザーとチャンネルグループの設定を確認してください。 | ||||||
|  |     + チャンネルモデルの設定も確認してください。 | ||||||
|  | 4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value" | ||||||
|  |     + このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。 | ||||||
|  |     + ほとんどの場合、デプロイサイトのIPかプロキシのノードが CloudFlare によってブロックされています。 | ||||||
|  | 5. ChatGPT Next Web でエラーが発生しました: "Failed to fetch" | ||||||
|  |     + デプロイ時に `BASE_URL` を設定しないでください。 | ||||||
|  |     + インターフェイスアドレスと API Key が正しいか再確認してください。 | ||||||
|  |  | ||||||
|  | ## 関連プロジェクト | ||||||
|  | [FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム | ||||||
|  |  | ||||||
|  | ## 注 | ||||||
|  | 本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。 | ||||||
|  |  | ||||||
|  | このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。 | ||||||
|  |  | ||||||
|  | このプロジェクトを基にした派生プロジェクトについても同様です。 | ||||||
|  |  | ||||||
|  | 帰属表示を含めたくない場合は、事前に許可を得なければなりません。 | ||||||
|  |  | ||||||
|  | MIT ライセンスによると、このプロジェクトを利用するリスクと責任は利用者が負うべきであり、このオープンソースプロジェクトの開発者は責任を負いません。 | ||||||
							
								
								
									
										151
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										151
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,5 +1,5 @@ | |||||||
| <p align="right"> | <p align="right"> | ||||||
|    <strong>中文</strong> | <a href="./README.en.md">English</a> |    <strong>中文</strong> | <a href="./README.en.md">English</a> | <a href="./README.ja.md">日本語</a> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -11,7 +11,7 @@ | |||||||
|  |  | ||||||
| # One API | # One API | ||||||
|  |  | ||||||
| _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用✨_ | _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
|  |  | ||||||
| @@ -51,64 +51,75 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 | |||||||
|   <a href="https://iamazing.cn/page/reward">赞赏支持</a> |   <a href="https://iamazing.cn/page/reward">赞赏支持</a> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
| > **Note**:本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | > **Note** | ||||||
|  | > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||||
|  | >  | ||||||
|  | > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | ||||||
|  |  | ||||||
| > **Note**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | > **Warning** | ||||||
|  | > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | ||||||
|  |  | ||||||
| > **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。 | > **Warning** | ||||||
|  | > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | ||||||
|  |  | ||||||
| ## 功能 | ## 功能 | ||||||
| 1. 支持多种 API 访问渠道: | 1. 支持多种大模型: | ||||||
|    + [x] OpenAI 官方通道(支持配置镜像) |    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|    + [x] **Azure OpenAI API** |    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||||
|    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) |    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) | ||||||
|  |    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|  |    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|  |    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||||
|  |    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||||
|  |    + [x] [360 智脑](https://ai.360.cn) | ||||||
|  | 2. 支持配置镜像以及众多第三方代理服务: | ||||||
|    + [x] [OpenAI-SB](https://openai-sb.com) |    + [x] [OpenAI-SB](https://openai-sb.com) | ||||||
|  |    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||||
|    + [x] [API2D](https://api2d.com/r/197971) |    + [x] [API2D](https://api2d.com/r/197971) | ||||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) |    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) |    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) |  | ||||||
|    + [x] 自定义渠道:例如各种未收录的第三方代理服务 |    + [x] 自定义渠道:例如各种未收录的第三方代理服务 | ||||||
| 2. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| 4. 支持**多机部署**,[详见此处](#多机部署)。 | 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||||
| 5. 支持**令牌管理**,设置令牌的过期时间和额度。 | 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||||
| 6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||||
| 7. 支持**通道管理**,批量创建通道。 | 8. 支持**通道管理**,批量创建通道。 | ||||||
| 8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||||
| 9. 支持渠道**设置模型列表**。 | 10. 支持渠道**设置模型列表**。 | ||||||
| 10. 支持**查看额度明细**。 | 11. 支持**查看额度明细**。 | ||||||
| 11. 支持**用户邀请奖励**。 | 12. 支持**用户邀请奖励**。 | ||||||
| 12. 支持以美元为单位显示额度。 | 13. 支持以美元为单位显示额度。 | ||||||
| 13. 支持发布公告,设置充值链接,设置新用户初始额度。 | 14. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||||
| 14. 支持模型映射,重定向用户的请求模型。 | 15. 支持模型映射,重定向用户的请求模型。 | ||||||
| 15. 支持失败自动重试。 | 16. 支持失败自动重试。 | ||||||
| 16. 支持绘图接口。 | 17. 支持绘图接口。 | ||||||
| 17. 支持丰富的**自定义**设置, | 18. 支持丰富的**自定义**设置, | ||||||
|     1. 支持自定义系统名称,logo 以及页脚。 |     1. 支持自定义系统名称,logo 以及页脚。 | ||||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 |     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||||
| 18. 支持通过系统访问令牌访问管理 API。 | 19. 支持通过系统访问令牌访问管理 API。 | ||||||
| 19. 支持 Cloudflare Turnstile 用户校验。 | 20. 支持 Cloudflare Turnstile 用户校验。 | ||||||
| 20. 支持用户管理,支持**多种用户登录注册方式**: | 21. 支持用户管理,支持**多种用户登录注册方式**: | ||||||
|     + 邮箱登录注册以及通过邮箱进行密码重置。 |     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
| 21. 支持 [ChatGLM](https://github.com/THUDM/ChatGLM2-6B)。 |  | ||||||
| 22. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 |  | ||||||
|  |  | ||||||
| ## 部署 | ## 部署 | ||||||
| ### 基于 Docker 进行部署 | ### 基于 Docker 进行部署 | ||||||
| 部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` | 部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` | ||||||
|  |  | ||||||
| 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 | 其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 | ||||||
|  |  | ||||||
| 如果你的并发量较大,推荐设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 |  | ||||||
|  |  | ||||||
| 更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` |  | ||||||
|  |  | ||||||
| `-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 |  | ||||||
|  |  | ||||||
| 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | ||||||
|  |  | ||||||
|  | 如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 | ||||||
|  |  | ||||||
|  | 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 | ||||||
|  |  | ||||||
|  | 如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 | ||||||
|  |  | ||||||
|  | 更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | ||||||
|  |  | ||||||
| Nginx 的参考配置: | Nginx 的参考配置: | ||||||
| ``` | ``` | ||||||
| server{ | server{ | ||||||
| @@ -203,14 +214,23 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
|  |  | ||||||
| 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | ||||||
|  |  | ||||||
|  | #### QChatGPT - QQ机器人 | ||||||
|  | 项目主页:https://github.com/RockChinQ/QChatGPT | ||||||
|  |  | ||||||
|  | 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||||
|  |  | ||||||
|  | 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | ||||||
|  |  | ||||||
| ### 部署到第三方平台 | ### 部署到第三方平台 | ||||||
| <details> | <details> | ||||||
| <summary><strong>部署到 Sealos </strong></summary> | <summary><strong>部署到 Sealos </strong></summary> | ||||||
| <div> | <div> | ||||||
|  |  | ||||||
| > Sealos 可视化部署,仅需 1 分钟。 | > Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩。 | ||||||
|  |  | ||||||
| 参考这个[教程](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md)中 1~5 步。 | 点击以下按钮一键部署(部署后访问出现 404 请等待 3~5 分钟): | ||||||
|  |  | ||||||
|  | [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
| </details> | </details> | ||||||
| @@ -222,7 +242,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
| > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。 | > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。 | ||||||
|  |  | ||||||
| 1. 首先 fork 一份代码。 | 1. 首先 fork 一份代码。 | ||||||
| 2. 进入 [Zeabur](https://zeabur.com/),登录,进入控制台。 | 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | ||||||
| 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 | 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 | ||||||
| 4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。 | 4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。 | ||||||
| 5. 然后在 Service -> Add Service,选择 Git(第一次使用需要先授权),选择你 fork 的仓库。 | 5. 然后在 Service -> Add Service,选择 Git(第一次使用需要先授权),选择你 fork 的仓库。 | ||||||
| @@ -252,6 +272,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
|  |  | ||||||
| 注意,具体的 API Base 的格式取决于你所使用的客户端。 | 注意,具体的 API Base 的格式取决于你所使用的客户端。 | ||||||
|  |  | ||||||
|  | 例如对于 OpenAI 的官方库: | ||||||
|  | ```bash | ||||||
|  | OPENAI_API_KEY="sk-xxxxxx" | ||||||
|  | OPENAI_API_BASE="https://<HOST>:<PORT>/v1"  | ||||||
|  | ``` | ||||||
|  |  | ||||||
| ```mermaid | ```mermaid | ||||||
| graph LR | graph LR | ||||||
|     A(用户) |     A(用户) | ||||||
| @@ -267,32 +293,50 @@ graph LR | |||||||
| 不加的话将会使用负载均衡的方式使用多个渠道。 | 不加的话将会使用负载均衡的方式使用多个渠道。 | ||||||
|  |  | ||||||
| ### 环境变量 | ### 环境变量 | ||||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。 | 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` |    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
|  |    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||||
|    + 例子:`SESSION_SECRET=random_string` |    + 例子:`SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 8.0 版本。 | 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||||
|    + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |    + 例子: | ||||||
|  |      + MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
|  |      + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) | ||||||
|    + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 |    + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 | ||||||
|    + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 |    + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 | ||||||
|    + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 |    + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 | ||||||
|  |    + 请根据你的数据库配置修改下列参数(或者保持默认值): | ||||||
|  |      + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 | ||||||
|  |      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 | ||||||
|  |        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 | ||||||
|  |      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` |    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。 | 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|  |    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||||
|  | 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||||
|    + 例子:`SYNC_FREQUENCY=60` |    + 例子:`SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||||
|    + 例子:`NODE_TYPE=slave` |    + 例子:`NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` |    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` |    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||||
|     + 例子:`POLLING_INTERVAL=5` |     + 例子:`POLLING_INTERVAL=5` | ||||||
|  | 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|  |     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||||
|  |     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||||
|  | 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||||
|  |     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||||
|  | 13. 请求频率限制: | ||||||
|  |     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||||
|  |     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
|    + 例子:`--port 3000` |    + 例子:`--port 3000` | ||||||
| 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。 | 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 | ||||||
|    + 例子:`--log-dir ./logs` |    + 例子:`--log-dir ./logs` | ||||||
| 3. `--version`: 打印系统版本号并退出。 | 3. `--version`: 打印系统版本号并退出。 | ||||||
| 4. `--help`: 查看命令的使用帮助和参数说明。 | 4. `--help`: 查看命令的使用帮助和参数说明。 | ||||||
| @@ -311,6 +355,7 @@ https://openai.justsong.cn | |||||||
|    + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) |    + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) | ||||||
|    + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 |    + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 | ||||||
|    + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 |    + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 | ||||||
|  |    + 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 | ||||||
| 2. 账户额度足够为什么提示额度不足? | 2. 账户额度足够为什么提示额度不足? | ||||||
|    + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 |    + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 | ||||||
|    + 令牌额度仅供用户设置最大使用量,用户可自由设置。 |    + 令牌额度仅供用户设置最大使用量,用户可自由设置。 | ||||||
| @@ -323,11 +368,13 @@ https://openai.justsong.cn | |||||||
| 5. ChatGPT Next Web 报错:`Failed to fetch` | 5. ChatGPT Next Web 报错:`Failed to fetch` | ||||||
|    + 部署的时候不要设置 `BASE_URL`。 |    + 部署的时候不要设置 `BASE_URL`。 | ||||||
|    + 检查你的接口地址和 API Key 有没有填对。 |    + 检查你的接口地址和 API Key 有没有填对。 | ||||||
|  |    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||||
|    + 上游通道 429 了。 |    + 上游通道 429 了。 | ||||||
|  |  | ||||||
| ## 相关项目 | ## 相关项目 | ||||||
| [FastGPT](https://github.com/c121914yu/FastGPT): 三分钟搭建 AI 知识库 | * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||||
|  | * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用 | ||||||
|  |  | ||||||
| ## 注意 | ## 注意 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -42,6 +42,22 @@ var WeChatAuthEnabled = false | |||||||
| var TurnstileCheckEnabled = false | var TurnstileCheckEnabled = false | ||||||
| var RegisterEnabled = true | var RegisterEnabled = true | ||||||
|  |  | ||||||
|  | var EmailDomainRestrictionEnabled = false | ||||||
|  | var EmailDomainWhitelist = []string{ | ||||||
|  | 	"gmail.com", | ||||||
|  | 	"163.com", | ||||||
|  | 	"126.com", | ||||||
|  | 	"qq.com", | ||||||
|  | 	"outlook.com", | ||||||
|  | 	"hotmail.com", | ||||||
|  | 	"icloud.com", | ||||||
|  | 	"yahoo.com", | ||||||
|  | 	"foxmail.com", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||||
|  | var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||||
|  |  | ||||||
| var LogConsumeEnabled = true | var LogConsumeEnabled = true | ||||||
|  |  | ||||||
| var SMTPServer = "" | var SMTPServer = "" | ||||||
| @@ -77,6 +93,15 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | |||||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | var RequestInterval = time.Duration(requestInterval) * time.Second | ||||||
|  |  | ||||||
|  | var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||||
|  |  | ||||||
|  | var BatchUpdateEnabled = false | ||||||
|  | var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RequestIdKey = "X-Oneapi-Request-Id" | ||||||
|  | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	RoleGuestUser  = 0 | 	RoleGuestUser  = 0 | ||||||
| 	RoleCommonUser = 1 | 	RoleCommonUser = 1 | ||||||
| @@ -94,10 +119,10 @@ var ( | |||||||
| // All duration's unit is seconds | // All duration's unit is seconds | ||||||
| // Shouldn't larger then RateLimitKeyExpirationDuration | // Shouldn't larger then RateLimitKeyExpirationDuration | ||||||
| var ( | var ( | ||||||
| 	GlobalApiRateLimitNum            = 180 | 	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) | ||||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
| 	GlobalWebRateLimitNum            = 60 | 	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) | ||||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
| 	UploadRateLimitNum            = 10 | 	UploadRateLimitNum            = 10 | ||||||
| @@ -151,6 +176,15 @@ const ( | |||||||
| 	ChannelTypePaLM           = 11 | 	ChannelTypePaLM           = 11 | ||||||
| 	ChannelTypeAPI2GPT        = 12 | 	ChannelTypeAPI2GPT        = 12 | ||||||
| 	ChannelTypeAIGC2D         = 13 | 	ChannelTypeAIGC2D         = 13 | ||||||
|  | 	ChannelTypeAnthropic      = 14 | ||||||
|  | 	ChannelTypeBaidu          = 15 | ||||||
|  | 	ChannelTypeZhipu          = 16 | ||||||
|  | 	ChannelTypeAli            = 17 | ||||||
|  | 	ChannelTypeXunfei         = 18 | ||||||
|  | 	ChannelType360            = 19 | ||||||
|  | 	ChannelTypeOpenRouter     = 20 | ||||||
|  | 	ChannelTypeAIProxyLibrary = 21 | ||||||
|  | 	ChannelTypeFastGPT        = 22 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| @@ -168,4 +202,13 @@ var ChannelBaseURLs = []string{ | |||||||
| 	"",                                // 11 | 	"",                                // 11 | ||||||
| 	"https://api.api2gpt.com",         // 12 | 	"https://api.api2gpt.com",         // 12 | ||||||
| 	"https://api.aigc2d.com",          // 13 | 	"https://api.aigc2d.com",          // 13 | ||||||
|  | 	"https://api.anthropic.com",       // 14 | ||||||
|  | 	"https://aip.baidubce.com",        // 15 | ||||||
|  | 	"https://open.bigmodel.cn",        // 16 | ||||||
|  | 	"https://dashscope.aliyuncs.com",  // 17 | ||||||
|  | 	"",                                // 18 | ||||||
|  | 	"https://ai.360.cn",               // 19 | ||||||
|  | 	"https://openrouter.ai/api",       // 20 | ||||||
|  | 	"https://api.aiproxy.io",          // 21 | ||||||
|  | 	"https://fastgpt.run/api/openapi", // 22 | ||||||
| } | } | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ var ( | |||||||
| 	Port         = flag.Int("port", 3000, "the listening port") | 	Port         = flag.Int("port", 3000, "the listening port") | ||||||
| 	PrintVersion = flag.Bool("version", false, "print version and exit") | 	PrintVersion = flag.Bool("version", false, "print version and exit") | ||||||
| 	PrintHelp    = flag.Bool("help", false, "print help and exit") | 	PrintHelp    = flag.Bool("help", false, "print help and exit") | ||||||
| 	LogDir       = flag.String("log-dir", "", "specify the log directory") | 	LogDir       = flag.String("log-dir", "./logs", "specify the log directory") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func printHelp() { | func printHelp() { | ||||||
|   | |||||||
| @@ -1,29 +1,47 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"io" | 	"io" | ||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetupGinLog() { | const ( | ||||||
|  | 	loggerINFO  = "INFO" | ||||||
|  | 	loggerWarn  = "WARN" | ||||||
|  | 	loggerError = "ERR" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const maxLogCount = 1000000 | ||||||
|  |  | ||||||
|  | var logCount int | ||||||
|  | var setupLogLock sync.Mutex | ||||||
|  | var setupLogWorking bool | ||||||
|  |  | ||||||
|  | func SetupLogger() { | ||||||
| 	if *LogDir != "" { | 	if *LogDir != "" { | ||||||
| 		commonLogPath := filepath.Join(*LogDir, "common.log") | 		ok := setupLogLock.TryLock() | ||||||
| 		errorLogPath := filepath.Join(*LogDir, "error.log") | 		if !ok { | ||||||
| 		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | 			log.Println("setup log is already working") | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		defer func() { | ||||||
|  | 			setupLogLock.Unlock() | ||||||
|  | 			setupLogWorking = false | ||||||
|  | 		}() | ||||||
|  | 		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||||
|  | 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Fatal("failed to open log file") | 			log.Fatal("failed to open log file") | ||||||
| 		} | 		} | ||||||
| 		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | 		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) | ||||||
| 		if err != nil { | 		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) | ||||||
| 			log.Fatal("failed to open log file") |  | ||||||
| 		} |  | ||||||
| 		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd) |  | ||||||
| 		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -37,6 +55,36 @@ func SysError(s string) { | |||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func LogInfo(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerINFO, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func LogWarn(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerWarn, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func LogError(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerError, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func logHelper(ctx context.Context, level string, msg string) { | ||||||
|  | 	writer := gin.DefaultErrorWriter | ||||||
|  | 	if level == loggerINFO { | ||||||
|  | 		writer = gin.DefaultWriter | ||||||
|  | 	} | ||||||
|  | 	id := ctx.Value(RequestIdKey) | ||||||
|  | 	now := time.Now() | ||||||
|  | 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||||
|  | 	logCount++ // we don't need accurate count, so no lock here | ||||||
|  | 	if logCount > maxLogCount && !setupLogWorking { | ||||||
|  | 		logCount = 0 | ||||||
|  | 		setupLogWorking = true | ||||||
|  | 		go func() { | ||||||
|  | 			SetupLogger() | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func FatalLog(v ...any) { | func FatalLog(v ...any) { | ||||||
| 	t := time.Now() | 	t := time.Now() | ||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | ||||||
|   | |||||||
| @@ -1,12 +1,17 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import "encoding/json" | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
| // ModelRatio | // ModelRatio | ||||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||||
|  | // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||||
| // https://openai.com/pricing | // https://openai.com/pricing | ||||||
| // TODO: when a new api is enabled, check the pricing here | // TODO: when a new api is enabled, check the pricing here | ||||||
| // 1 === $0.002 / 1K tokens | // 1 === $0.002 / 1K tokens | ||||||
|  | // 1 === ¥0.014 / 1k tokens | ||||||
| var ModelRatio = map[string]float64{ | var ModelRatio = map[string]float64{ | ||||||
| 	"gpt-4":                     15, | 	"gpt-4":                     15, | ||||||
| 	"gpt-4-0314":                15, | 	"gpt-4-0314":                15, | ||||||
| @@ -19,6 +24,7 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"gpt-3.5-turbo-0613":        0.75, | 	"gpt-3.5-turbo-0613":        0.75, | ||||||
| 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-16k-0613":    1.5, | 	"gpt-3.5-turbo-16k-0613":    1.5, | ||||||
|  | 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens | ||||||
| 	"text-ada-001":              0.2, | 	"text-ada-001":              0.2, | ||||||
| 	"text-babbage-001":          0.25, | 	"text-babbage-001":          0.25, | ||||||
| 	"text-curie-001":            1, | 	"text-curie-001":            1, | ||||||
| @@ -26,7 +32,7 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"text-davinci-003":          10, | 	"text-davinci-003":          10, | ||||||
| 	"text-davinci-edit-001":     10, | 	"text-davinci-edit-001":     10, | ||||||
| 	"code-davinci-edit-001":     10, | 	"code-davinci-edit-001":     10, | ||||||
| 	"whisper-1":               10, | 	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||||
| 	"davinci":                   10, | 	"davinci":                   10, | ||||||
| 	"curie":                     10, | 	"curie":                     10, | ||||||
| 	"babbage":                   10, | 	"babbage":                   10, | ||||||
| @@ -36,6 +42,24 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"text-moderation-stable":    0.1, | 	"text-moderation-stable":    0.1, | ||||||
| 	"text-moderation-latest":    0.1, | 	"text-moderation-latest":    0.1, | ||||||
| 	"dall-e":                    8, | 	"dall-e":                    8, | ||||||
|  | 	"claude-instant-1":          0.815,  // $1.63 / 1M tokens | ||||||
|  | 	"claude-2":                  5.51,   // $11.02 / 1M tokens | ||||||
|  | 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens | ||||||
|  | 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens | ||||||
|  | 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | ||||||
|  | 	"PaLM-2":                    1, | ||||||
|  | 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||||
|  | 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||||
|  | 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens | ||||||
|  | 	"qwen-turbo":                0.8572, // ¥0.012 / 1k tokens | ||||||
|  | 	"qwen-plus":                 10,     // ¥0.14 / 1k tokens | ||||||
|  | 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | ||||||
|  | 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | ||||||
|  | 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||||
|  | 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||||
|  | 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||||
|  | 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | ||||||
|  | 	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens | ||||||
| } | } | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { | func ModelRatio2JSONString() string { | ||||||
| @@ -59,3 +83,19 @@ func GetModelRatio(name string) float64 { | |||||||
| 	} | 	} | ||||||
| 	return ratio | 	return ratio | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetCompletionRatio(name string) float64 { | ||||||
|  | 	if strings.HasPrefix(name, "gpt-3.5") { | ||||||
|  | 		return 1.333333 | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(name, "gpt-4") { | ||||||
|  | 		return 2 | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(name, "claude-instant-1") { | ||||||
|  | 		return 3.38 | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(name, "claude-2") { | ||||||
|  | 		return 2.965517 | ||||||
|  | 	} | ||||||
|  | 	return 1 | ||||||
|  | } | ||||||
|   | |||||||
| @@ -61,3 +61,8 @@ func RedisDel(key string) error { | |||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 	return RDB.Del(ctx, key).Err() | 	return RDB.Del(ctx, key).Err() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func RedisDecrease(key string, value int64) error { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	return RDB.DecrBy(ctx, key, value).Err() | ||||||
|  | } | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"log" | 	"log" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"os" | ||||||
| 	"os/exec" | 	"os/exec" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| @@ -170,6 +171,11 @@ func GetTimestamp() int64 { | |||||||
| 	return time.Now().Unix() | 	return time.Now().Unix() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetTimeString() string { | ||||||
|  | 	now := time.Now() | ||||||
|  | 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||||
|  | } | ||||||
|  |  | ||||||
| func Max(a int, b int) int { | func Max(a int, b int) int { | ||||||
| 	if a >= b { | 	if a >= b { | ||||||
| 		return a | 		return a | ||||||
| @@ -177,3 +183,19 @@ func Max(a int, b int) int { | |||||||
| 		return b | 		return b | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetOrDefault(env string, defaultValue int) int { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	num, err := strconv.Atoi(os.Getenv(env)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func MessageWithRequestId(message string, id string) string { | ||||||
|  | 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -11,9 +11,11 @@ func GetSubscription(c *gin.Context) { | |||||||
| 	var usedQuota int | 	var usedQuota int | ||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
|  | 	var expiredTime int64 | ||||||
| 	if common.DisplayTokenStatEnabled { | 	if common.DisplayTokenStatEnabled { | ||||||
| 		tokenId := c.GetInt("token_id") | 		tokenId := c.GetInt("token_id") | ||||||
| 		token, err = model.GetTokenById(tokenId) | 		token, err = model.GetTokenById(tokenId) | ||||||
|  | 		expiredTime = token.ExpiredTime | ||||||
| 		remainQuota = token.RemainQuota | 		remainQuota = token.RemainQuota | ||||||
| 		usedQuota = token.UsedQuota | 		usedQuota = token.UsedQuota | ||||||
| 	} else { | 	} else { | ||||||
| @@ -21,10 +23,13 @@ func GetSubscription(c *gin.Context) { | |||||||
| 		remainQuota, err = model.GetUserQuota(userId) | 		remainQuota, err = model.GetUserQuota(userId) | ||||||
| 		usedQuota, err = model.GetUserUsedQuota(userId) | 		usedQuota, err = model.GetUserUsedQuota(userId) | ||||||
| 	} | 	} | ||||||
|  | 	if expiredTime <= 0 { | ||||||
|  | 		expiredTime = 0 | ||||||
|  | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		openAIError := OpenAIError{ | 		openAIError := OpenAIError{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "one_api_error", | 			Type:    "upstream_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": openAIError, | ||||||
| @@ -45,6 +50,7 @@ func GetSubscription(c *gin.Context) { | |||||||
| 		SoftLimitUSD:       amount, | 		SoftLimitUSD:       amount, | ||||||
| 		HardLimitUSD:       amount, | 		HardLimitUSD:       amount, | ||||||
| 		SystemHardLimitUSD: amount, | 		SystemHardLimitUSD: amount, | ||||||
|  | 		AccessUntil:        expiredTime, | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, subscription) | 	c.JSON(200, subscription) | ||||||
| 	return | 	return | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ type OpenAISubscriptionResponse struct { | |||||||
| 	SoftLimitUSD       float64 `json:"soft_limit_usd"` | 	SoftLimitUSD       float64 `json:"soft_limit_usd"` | ||||||
| 	HardLimitUSD       float64 `json:"hard_limit_usd"` | 	HardLimitUSD       float64 `json:"hard_limit_usd"` | ||||||
| 	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` | 	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` | ||||||
|  | 	AccessUntil        int64   `json:"access_until"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type OpenAIUsageDailyCost struct { | type OpenAIUsageDailyCost struct { | ||||||
| @@ -84,7 +85,6 @@ func GetAuthHeader(token string) http.Header { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { | func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { | ||||||
| 	client := &http.Client{} |  | ||||||
| 	req, err := http.NewRequest(method, url, nil) | 	req, err := http.NewRequest(method, url, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -92,10 +92,13 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | |||||||
| 	for k := range headers { | 	for k := range headers { | ||||||
| 		req.Header.Add(k, headers.Get(k)) | 		req.Header.Add(k, headers.Get(k)) | ||||||
| 	} | 	} | ||||||
| 	res, err := client.Do(req) | 	res, err := httpClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	if res.StatusCode != http.StatusOK { | ||||||
|  | 		return nil, fmt.Errorf("status code: %d", res.StatusCode) | ||||||
|  | 	} | ||||||
| 	body, err := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -108,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | |||||||
| } | } | ||||||
|  |  | ||||||
| func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | ||||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) | 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) | ||||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -198,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | |||||||
|  |  | ||||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | 	baseURL := common.ChannelBaseURLs[channel.Type] | ||||||
| 	if channel.BaseURL == "" { | 	if channel.GetBaseURL() == "" { | ||||||
| 		channel.BaseURL = baseURL | 		channel.BaseURL = &baseURL | ||||||
| 	} | 	} | ||||||
| 	switch channel.Type { | 	switch channel.Type { | ||||||
| 	case common.ChannelTypeOpenAI: | 	case common.ChannelTypeOpenAI: | ||||||
| 		if channel.BaseURL != "" { | 		if channel.GetBaseURL() != "" { | ||||||
| 			baseURL = channel.BaseURL | 			baseURL = channel.GetBaseURL() | ||||||
| 		} | 		} | ||||||
| 	case common.ChannelTypeAzure: | 	case common.ChannelTypeAzure: | ||||||
| 		return 0, errors.New("尚未实现") | 		return 0, errors.New("尚未实现") | ||||||
| 	case common.ChannelTypeCustom: | 	case common.ChannelTypeCustom: | ||||||
| 		baseURL = channel.BaseURL | 		baseURL = channel.GetBaseURL() | ||||||
| 	case common.ChannelTypeCloseAI: | 	case common.ChannelTypeCloseAI: | ||||||
| 		return updateChannelCloseAIBalance(channel) | 		return updateChannelCloseAIBalance(channel) | ||||||
| 	case common.ChannelTypeOpenAISB: | 	case common.ChannelTypeOpenAISB: | ||||||
|   | |||||||
| @@ -14,30 +14,49 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func testChannel(channel *model.Channel, request ChatRequest) error { | func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | ||||||
| 	switch channel.Type { | 	switch channel.Type { | ||||||
|  | 	case common.ChannelTypePaLM: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeAnthropic: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeBaidu: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeZhipu: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeAli: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelType360: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeXunfei: | ||||||
|  | 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | ||||||
| 	case common.ChannelTypeAzure: | 	case common.ChannelTypeAzure: | ||||||
| 		request.Model = "gpt-35-turbo" | 		request.Model = "gpt-35-turbo" | ||||||
|  | 		defer func() { | ||||||
|  | 			if err != nil { | ||||||
|  | 				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
| 	default: | 	default: | ||||||
| 		request.Model = "gpt-3.5-turbo" | 		request.Model = "gpt-3.5-turbo" | ||||||
| 	} | 	} | ||||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||||
| 	if channel.Type == common.ChannelTypeAzure { | 	if channel.Type == common.ChannelTypeAzure { | ||||||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) | 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) | ||||||
| 	} else { | 	} else { | ||||||
| 		if channel.BaseURL != "" { | 		if channel.GetBaseURL() != "" { | ||||||
| 			requestURL = channel.BaseURL | 			requestURL = channel.GetBaseURL() | ||||||
| 		} | 		} | ||||||
| 		requestURL += "/v1/chat/completions" | 		requestURL += "/v1/chat/completions" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	jsonData, err := json.Marshal(request) | 	jsonData, err := json.Marshal(request) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) | 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	if channel.Type == common.ChannelTypeAzure { | 	if channel.Type == common.ChannelTypeAzure { | ||||||
| 		req.Header.Set("api-key", channel.Key) | 		req.Header.Set("api-key", channel.Key) | ||||||
| @@ -45,21 +64,20 @@ func testChannel(channel *model.Channel, request ChatRequest) error { | |||||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) | 		req.Header.Set("Authorization", "Bearer "+channel.Key) | ||||||
| 	} | 	} | ||||||
| 	req.Header.Set("Content-Type", "application/json") | 	req.Header.Set("Content-Type", "application/json") | ||||||
| 	client := &http.Client{} | 	resp, err := httpClient.Do(req) | ||||||
| 	resp, err := client.Do(req) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	defer resp.Body.Close() | 	defer resp.Body.Close() | ||||||
| 	var response TextResponse | 	var response TextResponse | ||||||
| 	err = json.NewDecoder(resp.Body).Decode(&response) | 	err = json.NewDecoder(resp.Body).Decode(&response) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	if response.Usage.CompletionTokens == 0 { | 	if response.Usage.CompletionTokens == 0 { | ||||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) | 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func buildTestRequest() *ChatRequest { | func buildTestRequest() *ChatRequest { | ||||||
| @@ -94,7 +112,7 @@ func TestChannel(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	testRequest := buildTestRequest() | 	testRequest := buildTestRequest() | ||||||
| 	tik := time.Now() | 	tik := time.Now() | ||||||
| 	err = testChannel(channel, *testRequest) | 	err, _ = testChannel(channel, *testRequest) | ||||||
| 	tok := time.Now() | 	tok := time.Now() | ||||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | 	milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 	go channel.UpdateResponseTime(milliseconds) | 	go channel.UpdateResponseTime(milliseconds) | ||||||
| @@ -158,13 +176,14 @@ func testAllChannels(notify bool) error { | |||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			tik := time.Now() | 			tik := time.Now() | ||||||
| 			err := testChannel(channel, *testRequest) | 			err, openaiErr := testChannel(channel, *testRequest) | ||||||
| 			tok := time.Now() | 			tok := time.Now() | ||||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | 			milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 			if err != nil || milliseconds > disableThreshold { |  | ||||||
| 			if milliseconds > disableThreshold { | 			if milliseconds > disableThreshold { | ||||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||||
|  | 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
|  | 			if shouldDisableChannel(openaiErr, -1) { | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			channel.UpdateResponseTime(milliseconds) | 			channel.UpdateResponseTime(milliseconds) | ||||||
|   | |||||||
| @@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	channel.CreatedTime = common.GetTimestamp() | 	channel.CreatedTime = common.GetTimestamp() | ||||||
| 	keys := strings.Split(channel.Key, "\n") | 	keys := strings.Split(channel.Key, "\n") | ||||||
| 	channels := make([]model.Channel, 0) | 	channels := make([]model.Channel, 0, len(keys)) | ||||||
| 	for _, key := range keys { | 	for _, key := range keys { | ||||||
| 		if key == "" { | 		if key == "" { | ||||||
| 			continue | 			continue | ||||||
|   | |||||||
| @@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
|  |  | ||||||
| func GitHubOAuth(c *gin.Context) { | func GitHubOAuth(c *gin.Context) { | ||||||
| 	session := sessions.Default(c) | 	session := sessions.Default(c) | ||||||
|  | 	state := c.Query("state") | ||||||
|  | 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||||
|  | 		c.JSON(http.StatusForbidden, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "state is empty or not same", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
| 	username := session.Get("username") | 	username := session.Get("username") | ||||||
| 	if username != nil { | 	if username != nil { | ||||||
| 		GitHubBind(c) | 		GitHubBind(c) | ||||||
| @@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) { | |||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GenerateOAuthCode(c *gin.Context) { | ||||||
|  | 	session := sessions.Default(c) | ||||||
|  | 	state := common.GetRandomString(12) | ||||||
|  | 	session.Set("oauth_state", state) | ||||||
|  | 	err := session.Save() | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    state, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| @@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) { | |||||||
| 	username := c.Query("username") | 	username := c.Query("username") | ||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
|  | 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserLogs(c *gin.Context) { | func GetUserLogs(c *gin.Context) { | ||||||
| @@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) { | |||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchAllLogs(c *gin.Context) { | func SearchAllLogs(c *gin.Context) { | ||||||
| 	keyword := c.Query("keyword") | 	keyword := c.Query("keyword") | ||||||
| 	logs, err := model.SearchAllLogs(keyword) | 	logs, err := model.SearchAllLogs(keyword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUserLogs(c *gin.Context) { | func SearchUserLogs(c *gin.Context) { | ||||||
| @@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) { | |||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	logs, err := model.SearchUserLogs(userId, keyword) | 	logs, err := model.SearchUserLogs(userId, keyword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetLogsStat(c *gin.Context) { | func GetLogsStat(c *gin.Context) { | ||||||
| @@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) { | |||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	username := c.Query("username") | 	username := c.Query("username") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
|  | 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) | ||||||
| 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") | 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data": gin.H{ | 		"data": gin.H{ | ||||||
| @@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) { | |||||||
| 			//"token": tokenNum, | 			//"token": tokenNum, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetLogsSelfStat(c *gin.Context) { | func GetLogsSelfStat(c *gin.Context) { | ||||||
| @@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) { | |||||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
|  | 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) | ||||||
| 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data": gin.H{ | 		"data": gin.H{ | ||||||
| @@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) { | |||||||
| 			//"token": tokenNum, | 			//"token": tokenNum, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DeleteHistoryLogs(c *gin.Context) { | ||||||
|  | 	targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64) | ||||||
|  | 	if targetTimestamp == 0 { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "target timestamp is required", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	count, err := model.DeleteOldLog(targetTimestamp) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    count, | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,10 +3,12 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetStatus(c *gin.Context) { | func GetStatus(c *gin.Context) { | ||||||
| @@ -78,6 +80,22 @@ func SendEmailVerification(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | 	if common.EmailDomainRestrictionEnabled { | ||||||
|  | 		allowed := false | ||||||
|  | 		for _, domain := range common.EmailDomainWhitelist { | ||||||
|  | 			if strings.HasSuffix(email, "@"+domain) { | ||||||
|  | 				allowed = true | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if !allowed { | ||||||
|  | 			c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 				"success": false, | ||||||
|  | 				"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中", | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	if model.IsEmailAlreadyTaken(email) { | 	if model.IsEmailAlreadyTaken(email) { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -127,8 +145,9 @@ func SendPasswordResetEmail(c *gin.Context) { | |||||||
| 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) | 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) | ||||||
| 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | ||||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | ||||||
| 		"<p>点击<a href='%s'>此处</a>进行密码重置。</p>"+ | 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, common.VerificationValidMinutes) | 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||||
|  | 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := common.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|   | |||||||
| @@ -63,6 +63,15 @@ func init() { | |||||||
| 			Root:       "dall-e", | 			Root:       "dall-e", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "whisper-1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "whisper-1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "gpt-3.5-turbo", | 			Id:         "gpt-3.5-turbo", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| @@ -108,6 +117,15 @@ func init() { | |||||||
| 			Root:       "gpt-3.5-turbo-16k-0613", | 			Root:       "gpt-3.5-turbo-16k-0613", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-3.5-turbo-instruct", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-3.5-turbo-instruct", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "gpt-4", | 			Id:         "gpt-4", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| @@ -253,21 +271,165 @@ func init() { | |||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "ChatGLM", | 			Id:         "claude-instant-1", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| 			Created:    1677649963, | 			Created:    1677649963, | ||||||
| 			OwnedBy:    "thudm", | 			OwnedBy:    "anturopic", | ||||||
| 			Permission: permission, | 			Permission: permission, | ||||||
| 			Root:       "ChatGLM", | 			Root:       "claude-instant-1", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "ChatGLM2", | 			Id:         "claude-2", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| 			Created:    1677649963, | 			Created:    1677649963, | ||||||
| 			OwnedBy:    "thudm", | 			OwnedBy:    "anturopic", | ||||||
| 			Permission: permission, | 			Permission: permission, | ||||||
| 			Root:       "ChatGLM2", | 			Root:       "claude-2", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "ERNIE-Bot", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "baidu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "ERNIE-Bot", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "ERNIE-Bot-turbo", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "baidu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "ERNIE-Bot-turbo", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "Embedding-V1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "baidu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "Embedding-V1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "PaLM-2", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "google", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "PaLM-2", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "chatglm_pro", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "zhipu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "chatglm_pro", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "chatglm_std", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "zhipu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "chatglm_std", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "chatglm_lite", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "zhipu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "chatglm_lite", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "qwen-turbo", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "ali", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "qwen-turbo", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "qwen-plus", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "ali", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "qwen-plus", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "text-embedding-v1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "ali", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "text-embedding-v1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "SparkDesk", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "xunfei", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "SparkDesk", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "360GPT_S2_V9", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "360", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "360GPT_S2_V9", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "embedding-bert-512-v1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "360", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "embedding-bert-512-v1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "embedding_s1_v1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "360", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "embedding_s1_v1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "semantic_similarity_s1_v1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "360", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "semantic_similarity_s1_v1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "360GPT_S2_V9.4", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "360", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "360GPT_S2_V9.4", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -2,11 +2,12 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetOptions(c *gin.Context) { | func GetOptions(c *gin.Context) { | ||||||
| @@ -49,6 +50,14 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | 	case "EmailDomainRestrictionEnabled": | ||||||
|  | 		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { | ||||||
|  | 			c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 				"success": false, | ||||||
|  | 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
| 	case "WeChatAuthEnabled": | 	case "WeChatAuthEnabled": | ||||||
| 		if option.Value == "true" && common.WeChatServerAddress == "" { | 		if option.Value == "true" && common.WeChatServerAddress == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
|   | |||||||
							
								
								
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,220 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 | ||||||
|  |  | ||||||
|  | type AIProxyLibraryRequest struct { | ||||||
|  | 	Model     string `json:"model"` | ||||||
|  | 	Query     string `json:"query"` | ||||||
|  | 	LibraryId string `json:"libraryId"` | ||||||
|  | 	Stream    bool   `json:"stream"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryError struct { | ||||||
|  | 	ErrCode int    `json:"errCode"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryDocument struct { | ||||||
|  | 	Title string `json:"title"` | ||||||
|  | 	URL   string `json:"url"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryResponse struct { | ||||||
|  | 	Success   bool                     `json:"success"` | ||||||
|  | 	Answer    string                   `json:"answer"` | ||||||
|  | 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||||
|  | 	AIProxyLibraryError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryStreamResponse struct { | ||||||
|  | 	Content   string                   `json:"content"` | ||||||
|  | 	Finish    bool                     `json:"finish"` | ||||||
|  | 	Model     string                   `json:"model"` | ||||||
|  | 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { | ||||||
|  | 	query := "" | ||||||
|  | 	if len(request.Messages) != 0 { | ||||||
|  | 		query = request.Messages[len(request.Messages)-1].Content | ||||||
|  | 	} | ||||||
|  | 	return &AIProxyLibraryRequest{ | ||||||
|  | 		Model:  request.Model, | ||||||
|  | 		Stream: request.Stream, | ||||||
|  | 		Query:  query, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { | ||||||
|  | 	if len(documents) == 0 { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	content := "\n\n参考文档:\n" | ||||||
|  | 	for i, document := range documents { | ||||||
|  | 		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) | ||||||
|  | 	} | ||||||
|  | 	return content | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { | ||||||
|  | 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents) | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: content, | ||||||
|  | 		}, | ||||||
|  | 		FinishReason: "stop", | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Id:      common.GetUUID(), | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||||
|  | 	choice.FinishReason = &stopFinishReason | ||||||
|  | 	return &ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      common.GetUUID(), | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = response.Content | ||||||
|  | 	return &ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      common.GetUUID(), | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   response.Model, | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var usage Usage | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 5 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:5] != "data:" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[5:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	var documents []AIProxyLibraryDocument | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var AIProxyLibraryResponse AIProxyLibraryStreamResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if len(AIProxyLibraryResponse.Documents) != 0 { | ||||||
|  | 				documents = AIProxyLibraryResponse.Documents | ||||||
|  | 			} | ||||||
|  | 			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			response := documentsAIProxyLibrary(documents) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var AIProxyLibraryResponse AIProxyLibraryResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if AIProxyLibraryResponse.ErrCode != 0 { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: AIProxyLibraryResponse.Message, | ||||||
|  | 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode), | ||||||
|  | 				Code:    AIProxyLibraryResponse.ErrCode, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
							
								
								
									
										329
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										329
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,329 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||||
|  |  | ||||||
|  | type AliMessage struct { | ||||||
|  | 	User string `json:"user"` | ||||||
|  | 	Bot  string `json:"bot"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliInput struct { | ||||||
|  | 	Prompt  string       `json:"prompt"` | ||||||
|  | 	History []AliMessage `json:"history"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliParameters struct { | ||||||
|  | 	TopP         float64 `json:"top_p,omitempty"` | ||||||
|  | 	TopK         int     `json:"top_k,omitempty"` | ||||||
|  | 	Seed         uint64  `json:"seed,omitempty"` | ||||||
|  | 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliChatRequest struct { | ||||||
|  | 	Model      string        `json:"model"` | ||||||
|  | 	Input      AliInput      `json:"input"` | ||||||
|  | 	Parameters AliParameters `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbeddingRequest struct { | ||||||
|  | 	Model string `json:"model"` | ||||||
|  | 	Input struct { | ||||||
|  | 		Texts []string `json:"texts"` | ||||||
|  | 	} `json:"input"` | ||||||
|  | 	Parameters *struct { | ||||||
|  | 		TextType string `json:"text_type,omitempty"` | ||||||
|  | 	} `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbedding struct { | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | 	TextIndex int       `json:"text_index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbeddingResponse struct { | ||||||
|  | 	Output struct { | ||||||
|  | 		Embeddings []AliEmbedding `json:"embeddings"` | ||||||
|  | 	} `json:"output"` | ||||||
|  | 	Usage AliUsage `json:"usage"` | ||||||
|  | 	AliError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliError struct { | ||||||
|  | 	Code      string `json:"code"` | ||||||
|  | 	Message   string `json:"message"` | ||||||
|  | 	RequestId string `json:"request_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliUsage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | 	TotalTokens  int `json:"total_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliOutput struct { | ||||||
|  | 	Text         string `json:"text"` | ||||||
|  | 	FinishReason string `json:"finish_reason"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliChatResponse struct { | ||||||
|  | 	Output AliOutput `json:"output"` | ||||||
|  | 	Usage  AliUsage  `json:"usage"` | ||||||
|  | 	AliError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||||
|  | 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||||
|  | 	prompt := "" | ||||||
|  | 	for i := 0; i < len(request.Messages); i++ { | ||||||
|  | 		message := request.Messages[i] | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, AliMessage{ | ||||||
|  | 				User: message.Content, | ||||||
|  | 				Bot:  "Okay", | ||||||
|  | 			}) | ||||||
|  | 			continue | ||||||
|  | 		} else { | ||||||
|  | 			if i == len(request.Messages)-1 { | ||||||
|  | 				prompt = message.Content | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			messages = append(messages, AliMessage{ | ||||||
|  | 				User: message.Content, | ||||||
|  | 				Bot:  request.Messages[i+1].Content, | ||||||
|  | 			}) | ||||||
|  | 			i++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return &AliChatRequest{ | ||||||
|  | 		Model: request.Model, | ||||||
|  | 		Input: AliInput{ | ||||||
|  | 			Prompt:  prompt, | ||||||
|  | 			History: messages, | ||||||
|  | 		}, | ||||||
|  | 		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's | ||||||
|  | 		//	TopP: request.TopP, | ||||||
|  | 		//	TopK: 50, | ||||||
|  | 		//	//Seed:         0, | ||||||
|  | 		//	//EnableSearch: false, | ||||||
|  | 		//}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { | ||||||
|  | 	return &AliEmbeddingRequest{ | ||||||
|  | 		Model: "text-embedding-v1", | ||||||
|  | 		Input: struct { | ||||||
|  | 			Texts []string `json:"texts"` | ||||||
|  | 		}{ | ||||||
|  | 			Texts: request.ParseInput(), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var aliResponse AliEmbeddingResponse | ||||||
|  | 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if aliResponse.Code != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: aliResponse.Message, | ||||||
|  | 				Type:    aliResponse.Code, | ||||||
|  | 				Param:   aliResponse.RequestId, | ||||||
|  | 				Code:    aliResponse.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||||
|  | 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||||
|  | 		Object: "list", | ||||||
|  | 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), | ||||||
|  | 		Model:  "text-embedding-v1", | ||||||
|  | 		Usage:  Usage{TotalTokens: response.Usage.TotalTokens}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, item := range response.Output.Embeddings { | ||||||
|  | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||||
|  | 			Object:    `embedding`, | ||||||
|  | 			Index:     item.TextIndex, | ||||||
|  | 			Embedding: item.Embedding, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return &openAIEmbeddingResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: response.Output.Text, | ||||||
|  | 		}, | ||||||
|  | 		FinishReason: response.Output.FinishReason, | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Id:      response.RequestId, | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 		Usage: Usage{ | ||||||
|  | 			PromptTokens:     response.Usage.InputTokens, | ||||||
|  | 			CompletionTokens: response.Usage.OutputTokens, | ||||||
|  | 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = aliResponse.Output.Text | ||||||
|  | 	if aliResponse.Output.FinishReason != "null" { | ||||||
|  | 		finishReason := aliResponse.Output.FinishReason | ||||||
|  | 		choice.FinishReason = &finishReason | ||||||
|  | 	} | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      aliResponse.RequestId, | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "ernie-bot", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var usage Usage | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 5 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:5] != "data:" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[5:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	lastResponseText := "" | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var aliResponse AliChatResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if aliResponse.Usage.OutputTokens != 0 { | ||||||
|  | 				usage.PromptTokens = aliResponse.Usage.InputTokens | ||||||
|  | 				usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||||
|  | 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||||
|  | 			} | ||||||
|  | 			response := streamResponseAli2OpenAI(&aliResponse) | ||||||
|  | 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||||
|  | 			lastResponseText = aliResponse.Output.Text | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var aliResponse AliChatResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &aliResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if aliResponse.Code != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: aliResponse.Message, | ||||||
|  | 				Type:    aliResponse.Code, | ||||||
|  | 				Param:   aliResponse.RequestId, | ||||||
|  | 				Code:    aliResponse.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
							
								
								
									
										153
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,153 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||||
|  | 	audioModel := "whisper-1" | ||||||
|  |  | ||||||
|  | 	tokenId := c.GetInt("token_id") | ||||||
|  | 	channelType := c.GetInt("channel") | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
|  | 	userId := c.GetInt("id") | ||||||
|  | 	group := c.GetString("group") | ||||||
|  |  | ||||||
|  | 	preConsumedTokens := common.PreConsumedQuota | ||||||
|  | 	modelRatio := common.GetModelRatio(audioModel) | ||||||
|  | 	groupRatio := common.GetGroupRatio(group) | ||||||
|  | 	ratio := modelRatio * groupRatio | ||||||
|  | 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||||
|  | 	userQuota, err := model.CacheGetUserQuota(userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	if userQuota-preConsumedQuota < 0 { | ||||||
|  | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
|  | 	} | ||||||
|  | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	if userQuota > 100*preConsumedQuota { | ||||||
|  | 		// in this case, we do not pre-consume quota | ||||||
|  | 		// because the user has enough quota | ||||||
|  | 		preConsumedQuota = 0 | ||||||
|  | 	} | ||||||
|  | 	if preConsumedQuota > 0 { | ||||||
|  | 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// map model name | ||||||
|  | 	modelMapping := c.GetString("model_mapping") | ||||||
|  | 	if modelMapping != "" { | ||||||
|  | 		modelMap := make(map[string]string) | ||||||
|  | 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		if modelMap[audioModel] != "" { | ||||||
|  | 			audioModel = modelMap[audioModel] | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
|  | 	requestURL := c.Request.URL.String() | ||||||
|  |  | ||||||
|  | 	if c.GetString("base_url") != "" { | ||||||
|  | 		baseURL = c.GetString("base_url") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||||
|  | 	requestBody := c.Request.Body | ||||||
|  |  | ||||||
|  | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||||
|  | 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
|  | 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
|  |  | ||||||
|  | 	resp, err := httpClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = req.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	err = c.Request.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	var audioResponse AudioResponse | ||||||
|  |  | ||||||
|  | 	defer func(ctx context.Context) { | ||||||
|  | 		go func() { | ||||||
|  | 			quota := countTokenText(audioResponse.Text, audioModel) | ||||||
|  | 			quotaDelta := quota - preConsumedQuota | ||||||
|  | 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||||
|  | 			} | ||||||
|  | 			err = model.CacheUpdateUserQuota(userId) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error update user quota cache: " + err.Error()) | ||||||
|  | 			} | ||||||
|  | 			if quota != 0 { | ||||||
|  | 				tokenName := c.GetString("token_name") | ||||||
|  | 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
|  | 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) | ||||||
|  | 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||||
|  | 				channelId := c.GetInt("channel_id") | ||||||
|  | 				model.UpdateChannelUsedQuota(channelId, quota) | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
|  | 	}(c.Request.Context()) | ||||||
|  |  | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &audioResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||||
|  |  | ||||||
|  | 	for k, v := range resp.Header { | ||||||
|  | 		c.Writer.Header().Set(k, v[0]) | ||||||
|  | 	} | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  |  | ||||||
|  | 	_, err = io.Copy(c.Writer, resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										359
									
								
								controller/relay-baidu.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										359
									
								
								controller/relay-baidu.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,359 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 | ||||||
|  |  | ||||||
|  | type BaiduTokenResponse struct { | ||||||
|  | 	ExpiresIn   int    `json:"expires_in"` | ||||||
|  | 	AccessToken string `json:"access_token"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduMessage struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduChatRequest struct { | ||||||
|  | 	Messages []BaiduMessage `json:"messages"` | ||||||
|  | 	Stream   bool           `json:"stream"` | ||||||
|  | 	UserId   string         `json:"user_id,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduError struct { | ||||||
|  | 	ErrorCode int    `json:"error_code"` | ||||||
|  | 	ErrorMsg  string `json:"error_msg"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduChatResponse struct { | ||||||
|  | 	Id               string `json:"id"` | ||||||
|  | 	Object           string `json:"object"` | ||||||
|  | 	Created          int64  `json:"created"` | ||||||
|  | 	Result           string `json:"result"` | ||||||
|  | 	IsTruncated      bool   `json:"is_truncated"` | ||||||
|  | 	NeedClearHistory bool   `json:"need_clear_history"` | ||||||
|  | 	Usage            Usage  `json:"usage"` | ||||||
|  | 	BaiduError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduChatStreamResponse struct { | ||||||
|  | 	BaiduChatResponse | ||||||
|  | 	SentenceId int  `json:"sentence_id"` | ||||||
|  | 	IsEnd      bool `json:"is_end"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduEmbeddingRequest struct { | ||||||
|  | 	Input []string `json:"input"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduEmbeddingData struct { | ||||||
|  | 	Object    string    `json:"object"` | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | 	Index     int       `json:"index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduEmbeddingResponse struct { | ||||||
|  | 	Id      string               `json:"id"` | ||||||
|  | 	Object  string               `json:"object"` | ||||||
|  | 	Created int64                `json:"created"` | ||||||
|  | 	Data    []BaiduEmbeddingData `json:"data"` | ||||||
|  | 	Usage   Usage                `json:"usage"` | ||||||
|  | 	BaiduError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduAccessToken struct { | ||||||
|  | 	AccessToken      string    `json:"access_token"` | ||||||
|  | 	Error            string    `json:"error,omitempty"` | ||||||
|  | 	ErrorDescription string    `json:"error_description,omitempty"` | ||||||
|  | 	ExpiresIn        int64     `json:"expires_in,omitempty"` | ||||||
|  | 	ExpiresAt        time.Time `json:"-"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var baiduTokenStore sync.Map | ||||||
|  |  | ||||||
|  | func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||||
|  | 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||||
|  | 	for _, message := range request.Messages { | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, BaiduMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, BaiduMessage{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 		} else { | ||||||
|  | 			messages = append(messages, BaiduMessage{ | ||||||
|  | 				Role:    message.Role, | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return &BaiduChatRequest{ | ||||||
|  | 		Messages: messages, | ||||||
|  | 		Stream:   request.Stream, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: response.Result, | ||||||
|  | 		}, | ||||||
|  | 		FinishReason: "stop", | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Id:      response.Id, | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: response.Created, | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 		Usage:   response.Usage, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = baiduResponse.Result | ||||||
|  | 	if baiduResponse.IsEnd { | ||||||
|  | 		choice.FinishReason = &stopFinishReason | ||||||
|  | 	} | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      baiduResponse.Id, | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: baiduResponse.Created, | ||||||
|  | 		Model:   "ernie-bot", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||||
|  | 	return &BaiduEmbeddingRequest{ | ||||||
|  | 		Input: request.ParseInput(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||||
|  | 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||||
|  | 		Object: "list", | ||||||
|  | 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | ||||||
|  | 		Model:  "baidu-embedding", | ||||||
|  | 		Usage:  response.Usage, | ||||||
|  | 	} | ||||||
|  | 	for _, item := range response.Data { | ||||||
|  | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||||
|  | 			Object:    item.Object, | ||||||
|  | 			Index:     item.Index, | ||||||
|  | 			Embedding: item.Embedding, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return &openAIEmbeddingResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var usage Usage | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 6 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[6:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var baiduResponse BaiduChatStreamResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if baiduResponse.Usage.TotalTokens != 0 { | ||||||
|  | 				usage.TotalTokens = baiduResponse.Usage.TotalTokens | ||||||
|  | 				usage.PromptTokens = baiduResponse.Usage.PromptTokens | ||||||
|  | 				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens | ||||||
|  | 			} | ||||||
|  | 			response := streamResponseBaidu2OpenAI(&baiduResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var baiduResponse BaiduChatResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if baiduResponse.ErrorMsg != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: baiduResponse.ErrorMsg, | ||||||
|  | 				Type:    "baidu_error", | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    baiduResponse.ErrorCode, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var baiduResponse BaiduEmbeddingResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if baiduResponse.ErrorMsg != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: baiduResponse.ErrorMsg, | ||||||
|  | 				Type:    "baidu_error", | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    baiduResponse.ErrorCode, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getBaiduAccessToken(apiKey string) (string, error) { | ||||||
|  | 	if val, ok := baiduTokenStore.Load(apiKey); ok { | ||||||
|  | 		var accessToken BaiduAccessToken | ||||||
|  | 		if accessToken, ok = val.(BaiduAccessToken); ok { | ||||||
|  | 			// soon this will expire | ||||||
|  | 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { | ||||||
|  | 				go func() { | ||||||
|  | 					_, _ = getBaiduAccessTokenHelper(apiKey) | ||||||
|  | 				}() | ||||||
|  | 			} | ||||||
|  | 			return accessToken.AccessToken, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	accessToken, err := getBaiduAccessTokenHelper(apiKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	if accessToken == nil { | ||||||
|  | 		return "", errors.New("getBaiduAccessToken return a nil token") | ||||||
|  | 	} | ||||||
|  | 	return (*accessToken).AccessToken, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { | ||||||
|  | 	parts := strings.Split(apiKey, "|") | ||||||
|  | 	if len(parts) != 2 { | ||||||
|  | 		return nil, errors.New("invalid baidu apikey") | ||||||
|  | 	} | ||||||
|  | 	req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", | ||||||
|  | 		parts[0], parts[1]), nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	req.Header.Add("Content-Type", "application/json") | ||||||
|  | 	req.Header.Add("Accept", "application/json") | ||||||
|  | 	res, err := impatientHTTPClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	var accessToken BaiduAccessToken | ||||||
|  | 	err = json.NewDecoder(res.Body).Decode(&accessToken) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if accessToken.Error != "" { | ||||||
|  | 		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) | ||||||
|  | 	} | ||||||
|  | 	if accessToken.AccessToken == "" { | ||||||
|  | 		return nil, errors.New("getBaiduAccessTokenHelper get empty access token") | ||||||
|  | 	} | ||||||
|  | 	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) | ||||||
|  | 	baiduTokenStore.Store(apiKey, accessToken) | ||||||
|  | 	return &accessToken, nil | ||||||
|  | } | ||||||
							
								
								
									
										220
									
								
								controller/relay-claude.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										220
									
								
								controller/relay-claude.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,220 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ClaudeMetadata struct { | ||||||
|  | 	UserId string `json:"user_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ClaudeRequest struct { | ||||||
|  | 	Model             string   `json:"model"` | ||||||
|  | 	Prompt            string   `json:"prompt"` | ||||||
|  | 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||||
|  | 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||||
|  | 	Temperature       float64  `json:"temperature,omitempty"` | ||||||
|  | 	TopP              float64  `json:"top_p,omitempty"` | ||||||
|  | 	TopK              int      `json:"top_k,omitempty"` | ||||||
|  | 	//ClaudeMetadata    `json:"metadata,omitempty"` | ||||||
|  | 	Stream bool `json:"stream,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ClaudeError struct { | ||||||
|  | 	Type    string `json:"type"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ClaudeResponse struct { | ||||||
|  | 	Completion string      `json:"completion"` | ||||||
|  | 	StopReason string      `json:"stop_reason"` | ||||||
|  | 	Model      string      `json:"model"` | ||||||
|  | 	Error      ClaudeError `json:"error"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func stopReasonClaude2OpenAI(reason string) string { | ||||||
|  | 	switch reason { | ||||||
|  | 	case "stop_sequence": | ||||||
|  | 		return "stop" | ||||||
|  | 	case "max_tokens": | ||||||
|  | 		return "length" | ||||||
|  | 	default: | ||||||
|  | 		return reason | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||||
|  | 	claudeRequest := ClaudeRequest{ | ||||||
|  | 		Model:             textRequest.Model, | ||||||
|  | 		Prompt:            "", | ||||||
|  | 		MaxTokensToSample: textRequest.MaxTokens, | ||||||
|  | 		StopSequences:     nil, | ||||||
|  | 		Temperature:       textRequest.Temperature, | ||||||
|  | 		TopP:              textRequest.TopP, | ||||||
|  | 		Stream:            textRequest.Stream, | ||||||
|  | 	} | ||||||
|  | 	if claudeRequest.MaxTokensToSample == 0 { | ||||||
|  | 		claudeRequest.MaxTokensToSample = 1000000 | ||||||
|  | 	} | ||||||
|  | 	prompt := "" | ||||||
|  | 	for _, message := range textRequest.Messages { | ||||||
|  | 		if message.Role == "user" { | ||||||
|  | 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | ||||||
|  | 		} else if message.Role == "assistant" { | ||||||
|  | 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||||
|  | 		} else if message.Role == "system" { | ||||||
|  | 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	prompt += "\n\nAssistant:" | ||||||
|  | 	claudeRequest.Prompt = prompt | ||||||
|  | 	return &claudeRequest | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = claudeResponse.Completion | ||||||
|  | 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) | ||||||
|  | 	if finishReason != "null" { | ||||||
|  | 		choice.FinishReason = &finishReason | ||||||
|  | 	} | ||||||
|  | 	var response ChatCompletionsStreamResponse | ||||||
|  | 	response.Object = "chat.completion.chunk" | ||||||
|  | 	response.Model = claudeResponse.Model | ||||||
|  | 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | ||||||
|  | 			Name:    nil, | ||||||
|  | 		}, | ||||||
|  | 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||||
|  | 	responseText := "" | ||||||
|  | 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||||
|  | 	createdTime := common.GetTimestamp() | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { | ||||||
|  | 			return i + 4, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if !strings.HasPrefix(data, "event: completion") { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = strings.TrimPrefix(data, "event: completion\r\ndata: ") | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			// some implementations may add \r at the end of data | ||||||
|  | 			data = strings.TrimSuffix(data, "\r") | ||||||
|  | 			var claudeResponse ClaudeResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			responseText += claudeResponse.Completion | ||||||
|  | 			response := streamResponseClaude2OpenAI(&claudeResponse) | ||||||
|  | 			response.Id = responseId | ||||||
|  | 			response.Created = createdTime | ||||||
|  | 			jsonStr, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
|  | 	} | ||||||
|  | 	return nil, responseText | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	var claudeResponse ClaudeResponse | ||||||
|  | 	err = json.Unmarshal(responseBody, &claudeResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if claudeResponse.Error.Type != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: claudeResponse.Error.Message, | ||||||
|  | 				Type:    claudeResponse.Error.Type, | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    claudeResponse.Error.Type, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||||
|  | 	completionTokens := countTokenText(claudeResponse.Completion, model) | ||||||
|  | 	usage := Usage{ | ||||||
|  | 		PromptTokens:     promptTokens, | ||||||
|  | 		CompletionTokens: completionTokens, | ||||||
|  | 		TotalTokens:      promptTokens + completionTokens, | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse.Usage = usage | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
| @@ -2,6 +2,7 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
|  |  | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	consumeQuota := c.GetBool("consume_quota") | 	consumeQuota := c.GetBool("consume_quota") | ||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
| @@ -97,7 +99,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||||
|  |  | ||||||
| 	if consumeQuota && userQuota-quota < 0 { | 	if consumeQuota && userQuota-quota < 0 { | ||||||
| 		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
| @@ -109,8 +111,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
|  |  | ||||||
| 	client := &http.Client{} | 	resp, err := httpClient.Do(req) | ||||||
| 	resp, err := client.Do(req) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| @@ -125,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	} | 	} | ||||||
| 	var textResponse ImageResponse | 	var textResponse ImageResponse | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func(ctx context.Context) { | ||||||
| 		if consumeQuota { | 		if consumeQuota { | ||||||
| 			err := model.PostConsumeTokenQuota(tokenId, quota) | 			err := model.PostConsumeTokenQuota(tokenId, quota) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -138,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 			if quota != 0 { | 			if quota != 0 { | ||||||
| 				tokenName := c.GetString("token_name") | 				tokenName := c.GetString("token_name") | ||||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) | 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||||
| 				channelId := c.GetInt("channel_id") | 				channelId := c.GetInt("channel_id") | ||||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | 				model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}(c.Request.Context()) | ||||||
|  |  | ||||||
| 	if consumeQuota { | 	if consumeQuota { | ||||||
| 		responseBody, err := io.ReadAll(resp.Body) | 		responseBody, err := io.ReadAll(resp.Body) | ||||||
|   | |||||||
							
								
								
									
										144
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { | ||||||
|  | 	responseText := "" | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 6 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:6] != "data: " && data[:6] != "[DONE]" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			dataChan <- data | ||||||
|  | 			data = data[6:] | ||||||
|  | 			if !strings.HasPrefix(data, "[DONE]") { | ||||||
|  | 				switch relayMode { | ||||||
|  | 				case RelayModeChatCompletions: | ||||||
|  | 					var streamResponse ChatCompletionsStreamResponse | ||||||
|  | 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||||
|  | 					if err != nil { | ||||||
|  | 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 						continue // just ignore the error | ||||||
|  | 					} | ||||||
|  | 					for _, choice := range streamResponse.Choices { | ||||||
|  | 						responseText += choice.Delta.Content | ||||||
|  | 					} | ||||||
|  | 				case RelayModeCompletions: | ||||||
|  | 					var streamResponse CompletionsStreamResponse | ||||||
|  | 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||||
|  | 					if err != nil { | ||||||
|  | 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 						continue | ||||||
|  | 					} | ||||||
|  | 					for _, choice := range streamResponse.Choices { | ||||||
|  | 						responseText += choice.Text | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			if strings.HasPrefix(data, "data: [DONE]") { | ||||||
|  | 				data = data[:12] | ||||||
|  | 			} | ||||||
|  | 			// some implementations may add \r at the end of data | ||||||
|  | 			data = strings.TrimSuffix(data, "\r") | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: data}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
|  | 	} | ||||||
|  | 	return nil, responseText | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var textResponse TextResponse | ||||||
|  | 	if consumeQuota { | ||||||
|  | 		responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 		} | ||||||
|  | 		err = resp.Body.Close() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 		} | ||||||
|  | 		err = json.Unmarshal(responseBody, &textResponse) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 		} | ||||||
|  | 		if textResponse.Error.Type != "" { | ||||||
|  | 			return &OpenAIErrorWithStatusCode{ | ||||||
|  | 				OpenAIError: textResponse.Error, | ||||||
|  | 				StatusCode:  resp.StatusCode, | ||||||
|  | 			}, nil | ||||||
|  | 		} | ||||||
|  | 		// Reset response body | ||||||
|  | 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||||
|  | 	} | ||||||
|  | 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||||
|  | 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||||
|  | 	// So the httpClient will be confused by the response. | ||||||
|  | 	// For example, Postman will report error, and we cannot check the response at all. | ||||||
|  | 	for k, v := range resp.Header { | ||||||
|  | 		c.Writer.Header().Set(k, v[0]) | ||||||
|  | 	} | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err := io.Copy(c.Writer, resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if textResponse.Usage.TotalTokens == 0 { | ||||||
|  | 		completionTokens := 0 | ||||||
|  | 		for _, choice := range textResponse.Choices { | ||||||
|  | 			completionTokens += countTokenText(choice.Message.Content, model) | ||||||
|  | 		} | ||||||
|  | 		textResponse.Usage = Usage{ | ||||||
|  | 			PromptTokens:     promptTokens, | ||||||
|  | 			CompletionTokens: completionTokens, | ||||||
|  | 			TotalTokens:      promptTokens + completionTokens, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil, &textResponse.Usage | ||||||
|  | } | ||||||
| @@ -1,10 +1,17 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||||
|  | // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||||
|  |  | ||||||
| type PaLMChatMessage struct { | type PaLMChatMessage struct { | ||||||
| 	Author  string `json:"author"` | 	Author  string `json:"author"` | ||||||
| 	Content string `json:"content"` | 	Content string `json:"content"` | ||||||
| @@ -15,45 +22,184 @@ type PaLMFilter struct { | |||||||
| 	Message string `json:"message"` | 	Message string `json:"message"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | type PaLMPrompt struct { | ||||||
| type PaLMChatRequest struct { | 	Messages []PaLMChatMessage `json:"messages"` | ||||||
| 	Prompt         []Message `json:"prompt"` | } | ||||||
| 	Temperature    float64   `json:"temperature"` |  | ||||||
| 	CandidateCount int       `json:"candidateCount"` | type PaLMChatRequest struct { | ||||||
| 	TopP           float64   `json:"topP"` | 	Prompt         PaLMPrompt `json:"prompt"` | ||||||
| 	TopK           int       `json:"topK"` | 	Temperature    float64    `json:"temperature,omitempty"` | ||||||
|  | 	CandidateCount int        `json:"candidateCount,omitempty"` | ||||||
|  | 	TopP           float64    `json:"topP,omitempty"` | ||||||
|  | 	TopK           int        `json:"topK,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type PaLMError struct { | ||||||
|  | 	Code    int    `json:"code"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | 	Status  string `json:"status"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body |  | ||||||
| type PaLMChatResponse struct { | type PaLMChatResponse struct { | ||||||
| 	Candidates []Message    `json:"candidates"` | 	Candidates []PaLMChatMessage `json:"candidates"` | ||||||
| 	Messages   []Message         `json:"messages"` | 	Messages   []Message         `json:"messages"` | ||||||
| 	Filters    []PaLMFilter      `json:"filters"` | 	Filters    []PaLMFilter      `json:"filters"` | ||||||
|  | 	Error      PaLMError         `json:"error"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode { | func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||||
| 	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage | 	palmRequest := PaLMChatRequest{ | ||||||
| 	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages)) | 		Prompt: PaLMPrompt{ | ||||||
| 	for _, message := range openAIRequest.Messages { | 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), | ||||||
| 		var author string | 		}, | ||||||
| 		if message.Role == "user" { | 		Temperature:    textRequest.Temperature, | ||||||
| 			author = "0" | 		CandidateCount: textRequest.N, | ||||||
| 		} else { | 		TopP:           textRequest.TopP, | ||||||
| 			author = "1" | 		TopK:           textRequest.MaxTokens, | ||||||
| 	} | 	} | ||||||
| 		messages = append(messages, PaLMChatMessage{ | 	for _, message := range textRequest.Messages { | ||||||
| 			Author:  author, | 		palmMessage := PaLMChatMessage{ | ||||||
| 			Content: message.Content, | 			Content: message.Content, | ||||||
| 		}) |  | ||||||
| 		} | 		} | ||||||
| 	request := PaLMChatRequest{ | 		if message.Role == "user" { | ||||||
| 		Prompt:         nil, | 			palmMessage.Author = "0" | ||||||
| 		Temperature:    openAIRequest.Temperature, | 		} else { | ||||||
| 		CandidateCount: openAIRequest.N, | 			palmMessage.Author = "1" | ||||||
| 		TopP:           openAIRequest.TopP, |  | ||||||
| 		TopK:           openAIRequest.MaxTokens, |  | ||||||
| 		} | 		} | ||||||
| 	// TODO: forward request to PaLM & convert response | 		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) | ||||||
| 	fmt.Print(request) | 	} | ||||||
| 	return nil | 	return &palmRequest | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | ||||||
|  | 	} | ||||||
|  | 	for i, candidate := range response.Candidates { | ||||||
|  | 		choice := OpenAITextResponseChoice{ | ||||||
|  | 			Index: i, | ||||||
|  | 			Message: Message{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: candidate.Content, | ||||||
|  | 			}, | ||||||
|  | 			FinishReason: "stop", | ||||||
|  | 		} | ||||||
|  | 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	if len(palmResponse.Candidates) > 0 { | ||||||
|  | 		choice.Delta.Content = palmResponse.Candidates[0].Content | ||||||
|  | 	} | ||||||
|  | 	choice.FinishReason = &stopFinishReason | ||||||
|  | 	var response ChatCompletionsStreamResponse | ||||||
|  | 	response.Object = "chat.completion.chunk" | ||||||
|  | 	response.Model = "palm2" | ||||||
|  | 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||||
|  | 	responseText := "" | ||||||
|  | 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||||
|  | 	createdTime := common.GetTimestamp() | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 		if err != nil { | ||||||
|  | 			common.SysError("error reading stream response: " + err.Error()) | ||||||
|  | 			stopChan <- true | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		err = resp.Body.Close() | ||||||
|  | 		if err != nil { | ||||||
|  | 			common.SysError("error closing stream response: " + err.Error()) | ||||||
|  | 			stopChan <- true | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		var palmResponse PaLMChatResponse | ||||||
|  | 		err = json.Unmarshal(responseBody, &palmResponse) | ||||||
|  | 		if err != nil { | ||||||
|  | 			common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 			stopChan <- true | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) | ||||||
|  | 		fullTextResponse.Id = responseId | ||||||
|  | 		fullTextResponse.Created = createdTime | ||||||
|  | 		if len(palmResponse.Candidates) > 0 { | ||||||
|  | 			responseText = palmResponse.Candidates[0].Content | ||||||
|  | 		} | ||||||
|  | 		jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 		if err != nil { | ||||||
|  | 			common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 			stopChan <- true | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		dataChan <- string(jsonResponse) | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + data}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
|  | 	} | ||||||
|  | 	return nil, responseText | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	var palmResponse PaLMChatResponse | ||||||
|  | 	err = json.Unmarshal(responseBody, &palmResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: palmResponse.Error.Message, | ||||||
|  | 				Type:    palmResponse.Error.Status, | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    palmResponse.Error.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) | ||||||
|  | 	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) | ||||||
|  | 	usage := Usage{ | ||||||
|  | 		PromptTokens:     promptTokens, | ||||||
|  | 		CompletionTokens: completionTokens, | ||||||
|  | 		TotalTokens:      promptTokens + completionTokens, | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse.Usage = usage | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &usage | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,22 +1,44 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bufio" |  | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	APITypeOpenAI = iota | ||||||
|  | 	APITypeClaude | ||||||
|  | 	APITypePaLM | ||||||
|  | 	APITypeBaidu | ||||||
|  | 	APITypeZhipu | ||||||
|  | 	APITypeAli | ||||||
|  | 	APITypeXunfei | ||||||
|  | 	APITypeAIProxyLibrary | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var httpClient *http.Client | ||||||
|  | var impatientHTTPClient *http.Client | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	httpClient = &http.Client{} | ||||||
|  | 	impatientHTTPClient = &http.Client{ | ||||||
|  | 		Timeout: 5 * time.Second, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	consumeQuota := c.GetBool("consume_quota") | 	consumeQuota := c.GetBool("consume_quota") | ||||||
| @@ -60,7 +82,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	// map model name | 	// map model name | ||||||
| 	modelMapping := c.GetString("model_mapping") | 	modelMapping := c.GetString("model_mapping") | ||||||
| 	isModelMapped := false | 	isModelMapped := false | ||||||
| 	if modelMapping != "" { | 	if modelMapping != "" && modelMapping != "{}" { | ||||||
| 		modelMap := make(map[string]string) | 		modelMap := make(map[string]string) | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -71,12 +93,31 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			isModelMapped = true | 			isModelMapped = true | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	apiType := APITypeOpenAI | ||||||
|  | 	switch channelType { | ||||||
|  | 	case common.ChannelTypeAnthropic: | ||||||
|  | 		apiType = APITypeClaude | ||||||
|  | 	case common.ChannelTypeBaidu: | ||||||
|  | 		apiType = APITypeBaidu | ||||||
|  | 	case common.ChannelTypePaLM: | ||||||
|  | 		apiType = APITypePaLM | ||||||
|  | 	case common.ChannelTypeZhipu: | ||||||
|  | 		apiType = APITypeZhipu | ||||||
|  | 	case common.ChannelTypeAli: | ||||||
|  | 		apiType = APITypeAli | ||||||
|  | 	case common.ChannelTypeXunfei: | ||||||
|  | 		apiType = APITypeXunfei | ||||||
|  | 	case common.ChannelTypeAIProxyLibrary: | ||||||
|  | 		apiType = APITypeAIProxyLibrary | ||||||
|  | 	} | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
| 	if c.GetString("base_url") != "" { | 	if c.GetString("base_url") != "" { | ||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||||
|  | 	switch apiType { | ||||||
|  | 	case APITypeOpenAI: | ||||||
| 		if channelType == common.ChannelTypeAzure { | 		if channelType == common.ChannelTypeAzure { | ||||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||||
| 			query := c.Request.URL.Query() | 			query := c.Request.URL.Query() | ||||||
| @@ -95,9 +136,51 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | 			model_ = strings.TrimSuffix(model_, "-0314") | ||||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | 			model_ = strings.TrimSuffix(model_, "-0613") | ||||||
| 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||||
| 	} else if channelType == common.ChannelTypePaLM { | 		} | ||||||
| 		err := relayPaLM(textRequest, c) | 	case APITypeClaude: | ||||||
| 		return err | 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||||
|  | 		if baseURL != "" { | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) | ||||||
|  | 		} | ||||||
|  | 	case APITypeBaidu: | ||||||
|  | 		switch textRequest.Model { | ||||||
|  | 		case "ERNIE-Bot": | ||||||
|  | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||||
|  | 		case "ERNIE-Bot-turbo": | ||||||
|  | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||||
|  | 		case "BLOOMZ-7B": | ||||||
|  | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||||
|  | 		case "Embedding-V1": | ||||||
|  | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||||
|  | 		} | ||||||
|  | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
|  | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
|  | 		var err error | ||||||
|  | 		if apiKey, err = getBaiduAccessToken(apiKey); err != nil { | ||||||
|  | 			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		fullRequestURL += "?access_token=" + apiKey | ||||||
|  | 	case APITypePaLM: | ||||||
|  | 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" | ||||||
|  | 		if baseURL != "" { | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) | ||||||
|  | 		} | ||||||
|  | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
|  | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
|  | 		fullRequestURL += "?key=" + apiKey | ||||||
|  | 	case APITypeZhipu: | ||||||
|  | 		method := "invoke" | ||||||
|  | 		if textRequest.Stream { | ||||||
|  | 			method = "sse-invoke" | ||||||
|  | 		} | ||||||
|  | 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||||
|  | 	case APITypeAli: | ||||||
|  | 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||||
|  | 		if relayMode == RelayModeEmbeddings { | ||||||
|  | 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||||
|  | 		} | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | ||||||
| 	} | 	} | ||||||
| 	var promptTokens int | 	var promptTokens int | ||||||
| 	var completionTokens int | 	var completionTokens int | ||||||
| @@ -121,10 +204,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	if userQuota > 10*preConsumedQuota { | 	if userQuota-preConsumedQuota < 0 { | ||||||
|  | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
|  | 	} | ||||||
|  | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	if userQuota > 100*preConsumedQuota { | ||||||
| 		// in this case, we do not pre-consume quota | 		// in this case, we do not pre-consume quota | ||||||
| 		// because the user has enough quota | 		// because the user has enough quota | ||||||
| 		preConsumedQuota = 0 | 		preConsumedQuota = 0 | ||||||
|  | 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||||
| 	} | 	} | ||||||
| 	if consumeQuota && preConsumedQuota > 0 { | 	if consumeQuota && preConsumedQuota > 0 { | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||||
| @@ -142,22 +233,112 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	} else { | 	} else { | ||||||
| 		requestBody = c.Request.Body | 		requestBody = c.Request.Body | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 	switch apiType { | ||||||
|  | 	case APITypeClaude: | ||||||
|  | 		claudeRequest := requestOpenAI2Claude(textRequest) | ||||||
|  | 		jsonStr, err := json.Marshal(claudeRequest) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeBaidu: | ||||||
|  | 		var jsonData []byte | ||||||
|  | 		var err error | ||||||
|  | 		switch relayMode { | ||||||
|  | 		case RelayModeEmbeddings: | ||||||
|  | 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | ||||||
|  | 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||||
|  | 		default: | ||||||
|  | 			baiduRequest := requestOpenAI2Baidu(textRequest) | ||||||
|  | 			jsonData, err = json.Marshal(baiduRequest) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonData) | ||||||
|  | 	case APITypePaLM: | ||||||
|  | 		palmRequest := requestOpenAI2PaLM(textRequest) | ||||||
|  | 		jsonStr, err := json.Marshal(palmRequest) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeZhipu: | ||||||
|  | 		zhipuRequest := requestOpenAI2Zhipu(textRequest) | ||||||
|  | 		jsonStr, err := json.Marshal(zhipuRequest) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeAli: | ||||||
|  | 		var jsonStr []byte | ||||||
|  | 		var err error | ||||||
|  | 		switch relayMode { | ||||||
|  | 		case RelayModeEmbeddings: | ||||||
|  | 			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) | ||||||
|  | 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | ||||||
|  | 		default: | ||||||
|  | 			aliRequest := requestOpenAI2Ali(textRequest) | ||||||
|  | 			jsonStr, err = json.Marshal(aliRequest) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | ||||||
|  | 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||||
|  | 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var req *http.Request | ||||||
|  | 	var resp *http.Response | ||||||
|  | 	isStream := textRequest.Stream | ||||||
|  |  | ||||||
|  | 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||||
|  | 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
|  | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
|  | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
|  | 		switch apiType { | ||||||
|  | 		case APITypeOpenAI: | ||||||
| 			if channelType == common.ChannelTypeAzure { | 			if channelType == common.ChannelTypeAzure { | ||||||
| 		key := c.Request.Header.Get("Authorization") | 				req.Header.Set("api-key", apiKey) | ||||||
| 		key = strings.TrimPrefix(key, "Bearer ") |  | ||||||
| 		req.Header.Set("api-key", key) |  | ||||||
| 			} else { | 			} else { | ||||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||||
|  | 				if channelType == common.ChannelTypeOpenRouter { | ||||||
|  | 					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") | ||||||
|  | 					req.Header.Set("X-Title", "One API") | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		case APITypeClaude: | ||||||
|  | 			req.Header.Set("x-api-key", apiKey) | ||||||
|  | 			anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||||
|  | 			if anthropicVersion == "" { | ||||||
|  | 				anthropicVersion = "2023-06-01" | ||||||
|  | 			} | ||||||
|  | 			req.Header.Set("anthropic-version", anthropicVersion) | ||||||
|  | 		case APITypeZhipu: | ||||||
|  | 			token := getZhipuToken(apiKey) | ||||||
|  | 			req.Header.Set("Authorization", token) | ||||||
|  | 		case APITypeAli: | ||||||
|  | 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||||
|  | 			if textRequest.Stream { | ||||||
|  | 				req.Header.Set("X-DashScope-SSE", "enable") | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||||
| 		} | 		} | ||||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||||
| 	client := &http.Client{} | 		resp, err = httpClient.Do(req) | ||||||
| 	resp, err := client.Do(req) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| @@ -169,26 +350,34 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 	var textResponse TextResponse | 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") |  | ||||||
| 	var streamResponseText string |  | ||||||
|  |  | ||||||
| 	defer func() { | 		if resp.StatusCode != http.StatusOK { | ||||||
|  | 			if preConsumedQuota != 0 { | ||||||
|  | 				go func(ctx context.Context) { | ||||||
|  | 					// return pre-consumed quota | ||||||
|  | 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||||
|  | 					if err != nil { | ||||||
|  | 						common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 					} | ||||||
|  | 				}(c.Request.Context()) | ||||||
|  | 			} | ||||||
|  | 			return relayErrorHandler(resp) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var textResponse TextResponse | ||||||
|  | 	tokenName := c.GetString("token_name") | ||||||
|  |  | ||||||
|  | 	defer func(ctx context.Context) { | ||||||
|  | 		// c.Writer.Flush() | ||||||
|  | 		go func() { | ||||||
| 			if consumeQuota { | 			if consumeQuota { | ||||||
| 				quota := 0 | 				quota := 0 | ||||||
| 			completionRatio := 1.0 | 				completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||||
| 			if strings.HasPrefix(textRequest.Model, "gpt-3.5") { |  | ||||||
| 				completionRatio = 1.333333 |  | ||||||
| 			} |  | ||||||
| 			if strings.HasPrefix(textRequest.Model, "gpt-4") { |  | ||||||
| 				completionRatio = 2 |  | ||||||
| 			} |  | ||||||
| 			if isStream { |  | ||||||
| 				completionTokens = countTokenText(streamResponseText, textRequest.Model) |  | ||||||
| 			} else { |  | ||||||
| 				promptTokens = textResponse.Usage.PromptTokens | 				promptTokens = textResponse.Usage.PromptTokens | ||||||
| 				completionTokens = textResponse.Usage.CompletionTokens | 				completionTokens = textResponse.Usage.CompletionTokens | ||||||
| 			} |  | ||||||
| 				quota = promptTokens + int(float64(completionTokens)*completionRatio) | 				quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||||
| 				quota = int(float64(quota) * ratio) | 				quota = int(float64(quota) * ratio) | ||||||
| 				if ratio != 0 && quota <= 0 { | 				if ratio != 0 && quota <= 0 { | ||||||
| @@ -203,140 +392,199 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 				quotaDelta := quota - preConsumedQuota | 				quotaDelta := quota - preConsumedQuota | ||||||
| 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | 					common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||||
| 				} | 				} | ||||||
| 				err = model.CacheUpdateUserQuota(userId) | 				err = model.CacheUpdateUserQuota(userId) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 				common.SysError("error update user quota cache: " + err.Error()) | 					common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||||
| 				} | 				} | ||||||
| 				if quota != 0 { | 				if quota != 0 { | ||||||
| 				tokenName := c.GetString("token_name") |  | ||||||
| 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 				model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | 					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||||
| 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||||
| 				channelId := c.GetInt("channel_id") |  | ||||||
| 					model.UpdateChannelUsedQuota(channelId, quota) | 					model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
|  | 	}(c.Request.Context()) | ||||||
|  | 	switch apiType { | ||||||
|  | 	case APITypeOpenAI: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 		scanner := bufio.NewScanner(resp.Body) | 			err, responseText := openaiStreamHandler(c, resp, relayMode) | ||||||
| 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 			if atEOF && len(data) == 0 { |  | ||||||
| 				return 0, nil, nil |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if i := strings.Index(string(data), "\n"); i >= 0 { |  | ||||||
| 				return i + 1, data[0:i], nil |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if atEOF { |  | ||||||
| 				return len(data), data, nil |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		}) |  | ||||||
| 		dataChan := make(chan string) |  | ||||||
| 		stopChan := make(chan bool) |  | ||||||
| 		go func() { |  | ||||||
| 			for scanner.Scan() { |  | ||||||
| 				data := scanner.Text() |  | ||||||
| 				if len(data) < 6 { // ignore blank line or wrong format |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
| 				dataChan <- data |  | ||||||
| 				data = data[6:] |  | ||||||
| 				if !strings.HasPrefix(data, "[DONE]") { |  | ||||||
| 					switch relayMode { |  | ||||||
| 					case RelayModeChatCompletions: |  | ||||||
| 						var streamResponse ChatCompletionsStreamResponse |  | ||||||
| 						err = json.Unmarshal([]byte(data), &streamResponse) |  | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 							common.SysError("error unmarshalling stream response: " + err.Error()) | 				return err | ||||||
| 							return |  | ||||||
| 			} | 			} | ||||||
| 						for _, choice := range streamResponse.Choices { | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
| 							streamResponseText += choice.Delta.Content | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
| 						} | 			return nil | ||||||
| 					case RelayModeCompletions: | 		} else { | ||||||
| 						var streamResponse CompletionsStreamResponse | 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) | ||||||
| 						err = json.Unmarshal([]byte(data), &streamResponse) |  | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 							common.SysError("error unmarshalling stream response: " + err.Error()) | 				return err | ||||||
| 							return |  | ||||||
| 			} | 			} | ||||||
| 						for _, choice := range streamResponse.Choices { | 			if usage != nil { | ||||||
| 							streamResponseText += choice.Text | 				textResponse.Usage = *usage | ||||||
| 			} | 			} | ||||||
|  | 			return nil | ||||||
| 		} | 		} | ||||||
| 				} | 	case APITypeClaude: | ||||||
| 			} | 		if isStream { | ||||||
| 			stopChan <- true | 			err, responseText := claudeStreamHandler(c, resp) | ||||||
| 		}() |  | ||||||
| 		c.Writer.Header().Set("Content-Type", "text/event-stream") |  | ||||||
| 		c.Writer.Header().Set("Cache-Control", "no-cache") |  | ||||||
| 		c.Writer.Header().Set("Connection", "keep-alive") |  | ||||||
| 		c.Writer.Header().Set("Transfer-Encoding", "chunked") |  | ||||||
| 		c.Writer.Header().Set("X-Accel-Buffering", "no") |  | ||||||
| 		c.Stream(func(w io.Writer) bool { |  | ||||||
| 			select { |  | ||||||
| 			case data := <-dataChan: |  | ||||||
| 				if strings.HasPrefix(data, "data: [DONE]") { |  | ||||||
| 					data = data[:12] |  | ||||||
| 				} |  | ||||||
| 				// some implementations may add \r at the end of data |  | ||||||
| 				data = strings.TrimSuffix(data, "\r") |  | ||||||
| 				c.Render(-1, common.CustomEvent{Data: data}) |  | ||||||
| 				return true |  | ||||||
| 			case <-stopChan: |  | ||||||
| 				return false |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 		err = resp.Body.Close() |  | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 				return err | ||||||
|  | 			} | ||||||
|  | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
|  | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	case APITypeBaidu: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage := baiduStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 		if consumeQuota { | 			var err *OpenAIErrorWithStatusCode | ||||||
| 			responseBody, err := io.ReadAll(resp.Body) | 			var usage *Usage | ||||||
|  | 			switch relayMode { | ||||||
|  | 			case RelayModeEmbeddings: | ||||||
|  | 				err, usage = baiduEmbeddingHandler(c, resp) | ||||||
|  | 			default: | ||||||
|  | 				err, usage = baiduHandler(c, resp) | ||||||
|  | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | 				return err | ||||||
| 			} | 			} | ||||||
| 			err = resp.Body.Close() | 			if usage != nil { | ||||||
| 			if err != nil { | 				textResponse.Usage = *usage | ||||||
| 				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 			} |  | ||||||
| 			err = json.Unmarshal(responseBody, &textResponse) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 			} |  | ||||||
| 			if textResponse.Error.Type != "" { |  | ||||||
| 				return &OpenAIErrorWithStatusCode{ |  | ||||||
| 					OpenAIError: textResponse.Error, |  | ||||||
| 					StatusCode:  resp.StatusCode, |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			// Reset response body |  | ||||||
| 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) |  | ||||||
| 		} |  | ||||||
| 		// We shouldn't set the header before we parse the response body, because the parse part may fail. |  | ||||||
| 		// And then we will have to send an error response, but in this case, the header has already been set. |  | ||||||
| 		// So the client will be confused by the response. |  | ||||||
| 		// For example, Postman will report error, and we cannot check the response at all. |  | ||||||
| 		for k, v := range resp.Header { |  | ||||||
| 			c.Writer.Header().Set(k, v[0]) |  | ||||||
| 		} |  | ||||||
| 		c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 		_, err = io.Copy(c.Writer, resp.Body) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		err = resp.Body.Close() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
|  | 	case APITypePaLM: | ||||||
|  | 		if textRequest.Stream { // PaLM2 API does not support stream | ||||||
|  | 			err, responseText := palmStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
|  | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	case APITypeZhipu: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage := zhipuStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			// zhipu's API does not return prompt tokens & completion tokens | ||||||
|  | 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := zhipuHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			// zhipu's API does not return prompt tokens & completion tokens | ||||||
|  | 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	case APITypeAli: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage := aliStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			var err *OpenAIErrorWithStatusCode | ||||||
|  | 			var usage *Usage | ||||||
|  | 			switch relayMode { | ||||||
|  | 			case RelayModeEmbeddings: | ||||||
|  | 				err, usage = aliEmbeddingHandler(c, resp) | ||||||
|  | 			default: | ||||||
|  | 				err, usage = aliHandler(c, resp) | ||||||
|  | 			} | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	case APITypeXunfei: | ||||||
|  | 		auth := c.Request.Header.Get("Authorization") | ||||||
|  | 		auth = strings.TrimPrefix(auth, "Bearer ") | ||||||
|  | 		splits := strings.Split(auth, "|") | ||||||
|  | 		if len(splits) != 3 { | ||||||
|  | 			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||||
|  | 		} | ||||||
|  | 		var err *OpenAIErrorWithStatusCode | ||||||
|  | 		var usage *Usage | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||||
|  | 		} else { | ||||||
|  | 			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if usage != nil { | ||||||
|  | 			textResponse.Usage = *usage | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage := aiProxyLibraryStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := aiProxyLibraryHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,27 +1,61 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/pkoukk/tiktoken-go" | 	"github.com/pkoukk/tiktoken-go" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var stopFinishReason = "stop" | ||||||
|  |  | ||||||
|  | // tokenEncoderMap won't grow after initialization | ||||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||||
|  | var defaultTokenEncoder *tiktoken.Tiktoken | ||||||
|  |  | ||||||
|  | func InitTokenEncoders() { | ||||||
|  | 	common.SysLog("initializing token encoders") | ||||||
|  | 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | 	defaultTokenEncoder = gpt35TokenEncoder | ||||||
|  | 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | 	for model, _ := range common.ModelRatio { | ||||||
|  | 		if strings.HasPrefix(model, "gpt-3.5") { | ||||||
|  | 			tokenEncoderMap[model] = gpt35TokenEncoder | ||||||
|  | 		} else if strings.HasPrefix(model, "gpt-4") { | ||||||
|  | 			tokenEncoderMap[model] = gpt4TokenEncoder | ||||||
|  | 		} else { | ||||||
|  | 			tokenEncoderMap[model] = nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	common.SysLog("token encoders initialized") | ||||||
|  | } | ||||||
|  |  | ||||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||||
| 	if tokenEncoder, ok := tokenEncoderMap[model]; ok { | 	tokenEncoder, ok := tokenEncoderMap[model] | ||||||
|  | 	if ok && tokenEncoder != nil { | ||||||
| 		return tokenEncoder | 		return tokenEncoder | ||||||
| 	} | 	} | ||||||
|  | 	if ok { | ||||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||||
| 		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") | 			tokenEncoder = defaultTokenEncoder | ||||||
| 		if err != nil { |  | ||||||
| 			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) |  | ||||||
| 		} |  | ||||||
| 		} | 		} | ||||||
| 		tokenEncoderMap[model] = tokenEncoder | 		tokenEncoderMap[model] = tokenEncoder | ||||||
| 		return tokenEncoder | 		return tokenEncoder | ||||||
|  | 	} | ||||||
|  | 	return defaultTokenEncoder | ||||||
| } | } | ||||||
|  |  | ||||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||||
| @@ -91,3 +125,54 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus | |||||||
| 		StatusCode:  statusCode, | 		StatusCode:  statusCode, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func shouldDisableChannel(err *OpenAIError, statusCode int) bool { | ||||||
|  | 	if !common.AutomaticDisableChannelEnabled { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if err == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusUnauthorized { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func setEventStreamHeaders(c *gin.Context) { | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||||
|  | 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||||
|  | 	c.Writer.Header().Set("Connection", "keep-alive") | ||||||
|  | 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||||
|  | 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { | ||||||
|  | 	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ | ||||||
|  | 		StatusCode: resp.StatusCode, | ||||||
|  | 		OpenAIError: OpenAIError{ | ||||||
|  | 			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), | ||||||
|  | 			Type:    "upstream_error", | ||||||
|  | 			Code:    "bad_response_status_code", | ||||||
|  | 			Param:   strconv.Itoa(resp.StatusCode), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	var textResponse TextResponse | ||||||
|  | 	err = json.Unmarshal(responseBody, &textResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error | ||||||
|  | 	return | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										303
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										303
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,303 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/hmac" | ||||||
|  | 	"crypto/sha256" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/gorilla/websocket" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://console.xfyun.cn/services/cbm | ||||||
|  | // https://www.xfyun.cn/doc/spark/Web.html | ||||||
|  |  | ||||||
|  | type XunfeiMessage struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type XunfeiChatRequest struct { | ||||||
|  | 	Header struct { | ||||||
|  | 		AppId string `json:"app_id"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Parameter struct { | ||||||
|  | 		Chat struct { | ||||||
|  | 			Domain      string  `json:"domain,omitempty"` | ||||||
|  | 			Temperature float64 `json:"temperature,omitempty"` | ||||||
|  | 			TopK        int     `json:"top_k,omitempty"` | ||||||
|  | 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||||
|  | 			Auditing    bool    `json:"auditing,omitempty"` | ||||||
|  | 		} `json:"chat"` | ||||||
|  | 	} `json:"parameter"` | ||||||
|  | 	Payload struct { | ||||||
|  | 		Message struct { | ||||||
|  | 			Text []XunfeiMessage `json:"text"` | ||||||
|  | 		} `json:"message"` | ||||||
|  | 	} `json:"payload"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type XunfeiChatResponseTextItem struct { | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Index   int    `json:"index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type XunfeiChatResponse struct { | ||||||
|  | 	Header struct { | ||||||
|  | 		Code    int    `json:"code"` | ||||||
|  | 		Message string `json:"message"` | ||||||
|  | 		Sid     string `json:"sid"` | ||||||
|  | 		Status  int    `json:"status"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Payload struct { | ||||||
|  | 		Choices struct { | ||||||
|  | 			Status int                          `json:"status"` | ||||||
|  | 			Seq    int                          `json:"seq"` | ||||||
|  | 			Text   []XunfeiChatResponseTextItem `json:"text"` | ||||||
|  | 		} `json:"choices"` | ||||||
|  | 		Usage struct { | ||||||
|  | 			//Text struct { | ||||||
|  | 			//	QuestionTokens   string `json:"question_tokens"` | ||||||
|  | 			//	PromptTokens     string `json:"prompt_tokens"` | ||||||
|  | 			//	CompletionTokens string `json:"completion_tokens"` | ||||||
|  | 			//	TotalTokens      string `json:"total_tokens"` | ||||||
|  | 			//} `json:"text"` | ||||||
|  | 			Text Usage `json:"text"` | ||||||
|  | 		} `json:"usage"` | ||||||
|  | 	} `json:"payload"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { | ||||||
|  | 	messages := make([]XunfeiMessage, 0, len(request.Messages)) | ||||||
|  | 	for _, message := range request.Messages { | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, XunfeiMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, XunfeiMessage{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 		} else { | ||||||
|  | 			messages = append(messages, XunfeiMessage{ | ||||||
|  | 				Role:    message.Role, | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	xunfeiRequest := XunfeiChatRequest{} | ||||||
|  | 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||||
|  | 	xunfeiRequest.Parameter.Chat.Domain = domain | ||||||
|  | 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||||
|  | 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||||
|  | 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||||
|  | 	xunfeiRequest.Payload.Message.Text = messages | ||||||
|  | 	return &xunfeiRequest | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | ||||||
|  | 	if len(response.Payload.Choices.Text) == 0 { | ||||||
|  | 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||||
|  | 			{ | ||||||
|  | 				Content: "", | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: response.Payload.Choices.Text[0].Content, | ||||||
|  | 		}, | ||||||
|  | 		FinishReason: stopFinishReason, | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 		Usage:   response.Payload.Usage.Text, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||||
|  | 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||||
|  | 			{ | ||||||
|  | 				Content: "", | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||||
|  | 	if xunfeiResponse.Payload.Choices.Status == 2 { | ||||||
|  | 		choice.FinishReason = &stopFinishReason | ||||||
|  | 	} | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "SparkDesk", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||||
|  | 	HmacWithShaToBase64 := func(algorithm, data, key string) string { | ||||||
|  | 		mac := hmac.New(sha256.New, []byte(key)) | ||||||
|  | 		mac.Write([]byte(data)) | ||||||
|  | 		encodeData := mac.Sum(nil) | ||||||
|  | 		return base64.StdEncoding.EncodeToString(encodeData) | ||||||
|  | 	} | ||||||
|  | 	ul, err := url.Parse(hostUrl) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Println(err) | ||||||
|  | 	} | ||||||
|  | 	date := time.Now().UTC().Format(time.RFC1123) | ||||||
|  | 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||||
|  | 	sign := strings.Join(signString, "\n") | ||||||
|  | 	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) | ||||||
|  | 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||||
|  | 		"hmac-sha256", "host date request-line", sha) | ||||||
|  | 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||||
|  | 	v := url.Values{} | ||||||
|  | 	v.Add("host", ul.Host) | ||||||
|  | 	v.Add("date", date) | ||||||
|  | 	v.Add("authorization", authorization) | ||||||
|  | 	callUrl := hostUrl + "?" + v.Encode() | ||||||
|  | 	return callUrl | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||||
|  | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	var usage Usage | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case xunfeiResponse := <-dataChan: | ||||||
|  | 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||||
|  | 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||||
|  | 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||||
|  | 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||||
|  | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	var usage Usage | ||||||
|  | 	var content string | ||||||
|  | 	var xunfeiResponse XunfeiChatResponse | ||||||
|  | 	stop := false | ||||||
|  | 	for !stop { | ||||||
|  | 		select { | ||||||
|  | 		case xunfeiResponse = <-dataChan: | ||||||
|  | 			content += xunfeiResponse.Payload.Choices.Text[0].Content | ||||||
|  | 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||||
|  | 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||||
|  | 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||||
|  | 		case stop = <-stopChan: | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	xunfeiResponse.Payload.Choices.Text[0].Content = content | ||||||
|  |  | ||||||
|  | 	response := responseXunfei2OpenAI(&xunfeiResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(response) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	_, _ = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { | ||||||
|  | 	d := websocket.Dialer{ | ||||||
|  | 		HandshakeTimeout: 5 * time.Second, | ||||||
|  | 	} | ||||||
|  | 	conn, resp, err := d.Dial(authUrl, nil) | ||||||
|  | 	if err != nil || resp.StatusCode != 101 { | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} | ||||||
|  | 	data := requestOpenAI2Xunfei(textRequest, appId, domain) | ||||||
|  | 	err = conn.WriteJSON(data) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dataChan := make(chan XunfeiChatResponse) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			_, msg, err := conn.ReadMessage() | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error reading stream response: " + err.Error()) | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			var response XunfeiChatResponse | ||||||
|  | 			err = json.Unmarshal(msg, &response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			dataChan <- response | ||||||
|  | 			if response.Payload.Choices.Status == 2 { | ||||||
|  | 				err := conn.Close() | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("error closing websocket connection: " + err.Error()) | ||||||
|  | 				} | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	return dataChan, stopChan, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { | ||||||
|  | 	query := c.Request.URL.Query() | ||||||
|  | 	apiVersion := query.Get("api-version") | ||||||
|  | 	if apiVersion == "" { | ||||||
|  | 		apiVersion = c.GetString("api_version") | ||||||
|  | 	} | ||||||
|  | 	if apiVersion == "" { | ||||||
|  | 		apiVersion = "v1.1" | ||||||
|  | 		common.SysLog("api_version not found, use default: " + apiVersion) | ||||||
|  | 	} | ||||||
|  | 	domain := "general" | ||||||
|  | 	if apiVersion == "v2.1" { | ||||||
|  | 		domain = "generalv2" | ||||||
|  | 	} | ||||||
|  | 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||||
|  | 	return domain, authUrl | ||||||
|  | } | ||||||
							
								
								
									
										301
									
								
								controller/relay-zhipu.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										301
									
								
								controller/relay-zhipu.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,301 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/golang-jwt/jwt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://open.bigmodel.cn/doc/api#chatglm_std | ||||||
|  | // chatglm_std, chatglm_lite | ||||||
|  | // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke | ||||||
|  | // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke | ||||||
|  |  | ||||||
|  | type ZhipuMessage struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ZhipuRequest struct { | ||||||
|  | 	Prompt      []ZhipuMessage `json:"prompt"` | ||||||
|  | 	Temperature float64        `json:"temperature,omitempty"` | ||||||
|  | 	TopP        float64        `json:"top_p,omitempty"` | ||||||
|  | 	RequestId   string         `json:"request_id,omitempty"` | ||||||
|  | 	Incremental bool           `json:"incremental,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ZhipuResponseData struct { | ||||||
|  | 	TaskId     string         `json:"task_id"` | ||||||
|  | 	RequestId  string         `json:"request_id"` | ||||||
|  | 	TaskStatus string         `json:"task_status"` | ||||||
|  | 	Choices    []ZhipuMessage `json:"choices"` | ||||||
|  | 	Usage      `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ZhipuResponse struct { | ||||||
|  | 	Code    int               `json:"code"` | ||||||
|  | 	Msg     string            `json:"msg"` | ||||||
|  | 	Success bool              `json:"success"` | ||||||
|  | 	Data    ZhipuResponseData `json:"data"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ZhipuStreamMetaResponse struct { | ||||||
|  | 	RequestId  string `json:"request_id"` | ||||||
|  | 	TaskId     string `json:"task_id"` | ||||||
|  | 	TaskStatus string `json:"task_status"` | ||||||
|  | 	Usage      `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type zhipuTokenData struct { | ||||||
|  | 	Token      string | ||||||
|  | 	ExpiryTime time.Time | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var zhipuTokens sync.Map | ||||||
|  | var expSeconds int64 = 24 * 3600 | ||||||
|  |  | ||||||
|  | func getZhipuToken(apikey string) string { | ||||||
|  | 	data, ok := zhipuTokens.Load(apikey) | ||||||
|  | 	if ok { | ||||||
|  | 		tokenData := data.(zhipuTokenData) | ||||||
|  | 		if time.Now().Before(tokenData.ExpiryTime) { | ||||||
|  | 			return tokenData.Token | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	split := strings.Split(apikey, ".") | ||||||
|  | 	if len(split) != 2 { | ||||||
|  | 		common.SysError("invalid zhipu key: " + apikey) | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	id := split[0] | ||||||
|  | 	secret := split[1] | ||||||
|  |  | ||||||
|  | 	expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 | ||||||
|  | 	expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) | ||||||
|  |  | ||||||
|  | 	timestamp := time.Now().UnixNano() / 1e6 | ||||||
|  |  | ||||||
|  | 	payload := jwt.MapClaims{ | ||||||
|  | 		"api_key":   id, | ||||||
|  | 		"exp":       expMillis, | ||||||
|  | 		"timestamp": timestamp, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) | ||||||
|  |  | ||||||
|  | 	token.Header["alg"] = "HS256" | ||||||
|  | 	token.Header["sign_type"] = "SIGN" | ||||||
|  |  | ||||||
|  | 	tokenString, err := token.SignedString([]byte(secret)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	zhipuTokens.Store(apikey, zhipuTokenData{ | ||||||
|  | 		Token:      tokenString, | ||||||
|  | 		ExpiryTime: expiryTime, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return tokenString | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||||
|  | 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | ||||||
|  | 	for _, message := range request.Messages { | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, ZhipuMessage{ | ||||||
|  | 				Role:    "system", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, ZhipuMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 		} else { | ||||||
|  | 			messages = append(messages, ZhipuMessage{ | ||||||
|  | 				Role:    message.Role, | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return &ZhipuRequest{ | ||||||
|  | 		Prompt:      messages, | ||||||
|  | 		Temperature: request.Temperature, | ||||||
|  | 		TopP:        request.TopP, | ||||||
|  | 		Incremental: false, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Id:      response.Data.TaskId, | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), | ||||||
|  | 		Usage:   response.Data.Usage, | ||||||
|  | 	} | ||||||
|  | 	for i, choice := range response.Data.Choices { | ||||||
|  | 		openaiChoice := OpenAITextResponseChoice{ | ||||||
|  | 			Index: i, | ||||||
|  | 			Message: Message{ | ||||||
|  | 				Role:    choice.Role, | ||||||
|  | 				Content: strings.Trim(choice.Content, "\""), | ||||||
|  | 			}, | ||||||
|  | 			FinishReason: "", | ||||||
|  | 		} | ||||||
|  | 		if i == len(response.Data.Choices)-1 { | ||||||
|  | 			openaiChoice.FinishReason = "stop" | ||||||
|  | 		} | ||||||
|  | 		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = zhipuResponse | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "chatglm", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = "" | ||||||
|  | 	choice.FinishReason = &stopFinishReason | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      zhipuResponse.RequestId, | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "chatglm", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response, &zhipuResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var usage *Usage | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { | ||||||
|  | 			return i + 2, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	metaChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			lines := strings.Split(data, "\n") | ||||||
|  | 			for i, line := range lines { | ||||||
|  | 				if len(line) < 5 { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 				if line[:5] == "data:" { | ||||||
|  | 					dataChan <- line[5:] | ||||||
|  | 					if i != len(lines)-1 { | ||||||
|  | 						dataChan <- "\n" | ||||||
|  | 					} | ||||||
|  | 				} else if line[:5] == "meta:" { | ||||||
|  | 					metaChan <- line[5:] | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			response := streamResponseZhipu2OpenAI(data) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case data := <-metaChan: | ||||||
|  | 			var zhipuResponse ZhipuStreamMetaResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &zhipuResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			usage = zhipuUsage | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	return nil, usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var zhipuResponse ZhipuResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if !zhipuResponse.Success { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: zhipuResponse.Msg, | ||||||
|  | 				Type:    "zhipu_error", | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    zhipuResponse.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
| @@ -24,6 +24,7 @@ const ( | |||||||
| 	RelayModeModerations | 	RelayModeModerations | ||||||
| 	RelayModeImagesGenerations | 	RelayModeImagesGenerations | ||||||
| 	RelayModeEdits | 	RelayModeEdits | ||||||
|  | 	RelayModeAudio | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/chat | // https://platform.openai.com/docs/api-reference/chat | ||||||
| @@ -40,6 +41,26 @@ type GeneralOpenAIRequest struct { | |||||||
| 	Input       any       `json:"input,omitempty"` | 	Input       any       `json:"input,omitempty"` | ||||||
| 	Instruction string    `json:"instruction,omitempty"` | 	Instruction string    `json:"instruction,omitempty"` | ||||||
| 	Size        string    `json:"size,omitempty"` | 	Size        string    `json:"size,omitempty"` | ||||||
|  | 	Functions   any       `json:"functions,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r GeneralOpenAIRequest) ParseInput() []string { | ||||||
|  | 	if r.Input == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	var input []string | ||||||
|  | 	switch r.Input.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		input = []string{r.Input.(string)} | ||||||
|  | 	case []any: | ||||||
|  | 		input = make([]string, 0, len(r.Input.([]any))) | ||||||
|  | 		for _, item := range r.Input.([]any) { | ||||||
|  | 			if str, ok := item.(string); ok { | ||||||
|  | 				input = append(input, str) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return input | ||||||
| } | } | ||||||
|  |  | ||||||
| type ChatRequest struct { | type ChatRequest struct { | ||||||
| @@ -62,6 +83,10 @@ type ImageRequest struct { | |||||||
| 	Size   string `json:"size"` | 	Size   string `json:"size"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type AudioResponse struct { | ||||||
|  | 	Text string `json:"text,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type Usage struct { | type Usage struct { | ||||||
| 	PromptTokens     int `json:"prompt_tokens"` | 	PromptTokens     int `json:"prompt_tokens"` | ||||||
| 	CompletionTokens int `json:"completion_tokens"` | 	CompletionTokens int `json:"completion_tokens"` | ||||||
| @@ -81,10 +106,38 @@ type OpenAIErrorWithStatusCode struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type TextResponse struct { | type TextResponse struct { | ||||||
|  | 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||||
| 	Usage   `json:"usage"` | 	Usage   `json:"usage"` | ||||||
| 	Error   OpenAIError `json:"error"` | 	Error   OpenAIError `json:"error"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type OpenAITextResponseChoice struct { | ||||||
|  | 	Index        int `json:"index"` | ||||||
|  | 	Message      `json:"message"` | ||||||
|  | 	FinishReason string `json:"finish_reason"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type OpenAITextResponse struct { | ||||||
|  | 	Id      string                     `json:"id"` | ||||||
|  | 	Object  string                     `json:"object"` | ||||||
|  | 	Created int64                      `json:"created"` | ||||||
|  | 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||||
|  | 	Usage   `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type OpenAIEmbeddingResponseItem struct { | ||||||
|  | 	Object    string    `json:"object"` | ||||||
|  | 	Index     int       `json:"index"` | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type OpenAIEmbeddingResponse struct { | ||||||
|  | 	Object string                        `json:"object"` | ||||||
|  | 	Data   []OpenAIEmbeddingResponseItem `json:"data"` | ||||||
|  | 	Model  string                        `json:"model"` | ||||||
|  | 	Usage  `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type ImageResponse struct { | type ImageResponse struct { | ||||||
| 	Created int `json:"created"` | 	Created int `json:"created"` | ||||||
| 	Data    []struct { | 	Data    []struct { | ||||||
| @@ -92,13 +145,19 @@ type ImageResponse struct { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type ChatCompletionsStreamResponse struct { | type ChatCompletionsStreamResponseChoice struct { | ||||||
| 	Choices []struct { |  | ||||||
| 	Delta struct { | 	Delta struct { | ||||||
| 		Content string `json:"content"` | 		Content string `json:"content"` | ||||||
| 	} `json:"delta"` | 	} `json:"delta"` | ||||||
| 		FinishReason string `json:"finish_reason"` | 	FinishReason *string `json:"finish_reason"` | ||||||
| 	} `json:"choices"` | } | ||||||
|  |  | ||||||
|  | type ChatCompletionsStreamResponse struct { | ||||||
|  | 	Id      string                                `json:"id"` | ||||||
|  | 	Object  string                                `json:"object"` | ||||||
|  | 	Created int64                                 `json:"created"` | ||||||
|  | 	Model   string                                `json:"model"` | ||||||
|  | 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type CompletionsStreamResponse struct { | type CompletionsStreamResponse struct { | ||||||
| @@ -124,15 +183,20 @@ func Relay(c *gin.Context) { | |||||||
| 		relayMode = RelayModeImagesGenerations | 		relayMode = RelayModeImagesGenerations | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||||
| 		relayMode = RelayModeEdits | 		relayMode = RelayModeEdits | ||||||
|  | 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||||
|  | 		relayMode = RelayModeAudio | ||||||
| 	} | 	} | ||||||
| 	var err *OpenAIErrorWithStatusCode | 	var err *OpenAIErrorWithStatusCode | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case RelayModeImagesGenerations: | 	case RelayModeImagesGenerations: | ||||||
| 		err = relayImageHelper(c, relayMode) | 		err = relayImageHelper(c, relayMode) | ||||||
|  | 	case RelayModeAudio: | ||||||
|  | 		err = relayAudioHelper(c, relayMode) | ||||||
| 	default: | 	default: | ||||||
| 		err = relayTextHelper(c, relayMode) | 		err = relayTextHelper(c, relayMode) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		requestId := c.GetString(common.RequestIdKey) | ||||||
| 		retryTimesStr := c.Query("retry") | 		retryTimesStr := c.Query("retry") | ||||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||||
| 		if retryTimesStr == "" { | 		if retryTimesStr == "" { | ||||||
| @@ -142,16 +206,17 @@ func Relay(c *gin.Context) { | |||||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||||
| 		} else { | 		} else { | ||||||
| 			if err.StatusCode == http.StatusTooManyRequests { | 			if err.StatusCode == http.StatusTooManyRequests { | ||||||
| 				err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" | 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | ||||||
| 			} | 			} | ||||||
|  | 			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | ||||||
| 			c.JSON(err.StatusCode, gin.H{ | 			c.JSON(err.StatusCode, gin.H{ | ||||||
| 				"error": err.OpenAIError, | 				"error": err.OpenAIError, | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 		channelId := c.GetInt("channel_id") | 		channelId := c.GetInt("channel_id") | ||||||
| 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 		if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") { | 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | ||||||
| 			channelId := c.GetInt("channel_id") | 			channelId := c.GetInt("channel_id") | ||||||
| 			channelName := c.GetString("channel_name") | 			channelName := c.GetString("channel_name") | ||||||
| 			disableChannel(channelId, channelName, err.Message) | 			disableChannel(channelId, channelName, err.Message) | ||||||
| @@ -173,10 +238,10 @@ func RelayNotImplemented(c *gin.Context) { | |||||||
|  |  | ||||||
| func RelayNotFound(c *gin.Context) { | func RelayNotFound(c *gin.Context) { | ||||||
| 	err := OpenAIError{ | 	err := OpenAIError{ | ||||||
| 		Message: fmt.Sprintf("API not found: %s:%s", c.Request.Method, c.Request.URL.Path), | 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | ||||||
| 		Type:    "one_api_error", | 		Type:    "invalid_request_error", | ||||||
| 		Param:   "", | 		Param:   "", | ||||||
| 		Code:    "api_not_found", | 		Code:    "", | ||||||
| 	} | 	} | ||||||
| 	c.JSON(http.StatusNotFound, gin.H{ | 	c.JSON(http.StatusNotFound, gin.H{ | ||||||
| 		"error": err, | 		"error": err, | ||||||
|   | |||||||
| @@ -109,10 +109,10 @@ func AddToken(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if len(token.Name) == 0 || len(token.Name) > 20 { | 	if len(token.Name) > 30 { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "令牌名称长度必须在1-20之间", | 			"message": "令牌名称过长", | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -171,6 +171,13 @@ func UpdateToken(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | 	if len(token.Name) > 30 { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "令牌名称过长", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
| 	cleanToken, err := model.GetTokenByIds(token.Id, userId) | 	cleanToken, err := model.GetTokenByIds(token.Id, userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|   | |||||||
| @@ -3,12 +3,13 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type LoginRequest struct { | type LoginRequest struct { | ||||||
| @@ -477,6 +478,16 @@ func DeleteUser(c *gin.Context) { | |||||||
|  |  | ||||||
| func DeleteSelf(c *gin.Context) { | func DeleteSelf(c *gin.Context) { | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt("id") | ||||||
|  | 	user, _ := model.GetUserById(id, false) | ||||||
|  |  | ||||||
|  | 	if user.Role == common.RoleRootUser { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "不能删除超级管理员账户", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	err := model.DeleteUserById(id) | 	err := model.DeleteUserById(id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								go.mod
									
									
									
									
									
								
							| @@ -11,31 +11,34 @@ require ( | |||||||
| 	github.com/gin-gonic/gin v1.9.1 | 	github.com/gin-gonic/gin v1.9.1 | ||||||
| 	github.com/go-playground/validator/v10 v10.14.0 | 	github.com/go-playground/validator/v10 v10.14.0 | ||||||
| 	github.com/go-redis/redis/v8 v8.11.5 | 	github.com/go-redis/redis/v8 v8.11.5 | ||||||
|  | 	github.com/golang-jwt/jwt v3.2.2+incompatible | ||||||
| 	github.com/google/uuid v1.3.0 | 	github.com/google/uuid v1.3.0 | ||||||
| 	github.com/pkoukk/tiktoken-go v0.1.1 | 	github.com/gorilla/websocket v1.5.0 | ||||||
|  | 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||||
| 	golang.org/x/crypto v0.9.0 | 	golang.org/x/crypto v0.9.0 | ||||||
| 	gorm.io/driver/mysql v1.4.3 | 	gorm.io/driver/mysql v1.4.3 | ||||||
| 	gorm.io/driver/sqlite v1.4.3 | 	gorm.io/driver/sqlite v1.4.3 | ||||||
| 	gorm.io/gorm v1.24.0 | 	gorm.io/gorm v1.25.0 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| require ( | require ( | ||||||
| 	github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect |  | ||||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | 	github.com/bytedance/sonic v1.9.1 // indirect | ||||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | ||||||
| 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | ||||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||||
| 	github.com/dlclark/regexp2 v1.8.1 // indirect | 	github.com/dlclark/regexp2 v1.10.0 // indirect | ||||||
| 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | ||||||
| 	github.com/gin-contrib/sse v0.1.0 // indirect | 	github.com/gin-contrib/sse v0.1.0 // indirect | ||||||
| 	github.com/go-playground/locales v0.14.1 // indirect | 	github.com/go-playground/locales v0.14.1 // indirect | ||||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||||
| 	github.com/go-sql-driver/mysql v1.6.0 // indirect | 	github.com/go-sql-driver/mysql v1.6.0 // indirect | ||||||
| 	github.com/goccy/go-json v0.10.2 // indirect | 	github.com/goccy/go-json v0.10.2 // indirect | ||||||
| 	github.com/gomodule/redigo v2.0.0+incompatible // indirect |  | ||||||
| 	github.com/gorilla/context v1.1.1 // indirect | 	github.com/gorilla/context v1.1.1 // indirect | ||||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | 	github.com/gorilla/securecookie v1.1.1 // indirect | ||||||
| 	github.com/gorilla/sessions v1.2.1 // indirect | 	github.com/gorilla/sessions v1.2.1 // indirect | ||||||
|  | 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||||
|  | 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | ||||||
|  | 	github.com/jackc/pgx/v5 v5.3.1 // indirect | ||||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||||
| 	github.com/jinzhu/now v1.1.5 // indirect | 	github.com/jinzhu/now v1.1.5 // indirect | ||||||
| 	github.com/json-iterator/go v1.1.12 // indirect | 	github.com/json-iterator/go v1.1.12 // indirect | ||||||
| @@ -54,4 +57,5 @@ require ( | |||||||
| 	golang.org/x/text v0.9.0 // indirect | 	golang.org/x/text v0.9.0 // indirect | ||||||
| 	google.golang.org/protobuf v1.30.0 // indirect | 	google.golang.org/protobuf v1.30.0 // indirect | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||||
|  | 	gorm.io/driver/postgres v1.5.2 // indirect | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										27
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,5 +1,3 @@ | |||||||
| github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04= |  | ||||||
| github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw= |  | ||||||
| github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | ||||||
| github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | ||||||
| github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||||
| @@ -14,8 +12,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c | |||||||
| github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= | ||||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | ||||||
| github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= | github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= | ||||||
| github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | ||||||
| github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | ||||||
| github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= | github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= | ||||||
| github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= | github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= | ||||||
| @@ -54,10 +52,10 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB | |||||||
| github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||||
| github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= | github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= | ||||||
| github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||||
|  | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||||
|  | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||||
| github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= | github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= | ||||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||||
| github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= |  | ||||||
| github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= |  | ||||||
| github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= | github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= | ||||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||||
| @@ -67,9 +65,16 @@ github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8 | |||||||
| github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | ||||||
| github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | ||||||
| github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | ||||||
| github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= |  | ||||||
| github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | ||||||
| github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | ||||||
|  | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||||
|  | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||||
|  | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= | ||||||
|  | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= | ||||||
|  | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= | ||||||
|  | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= | ||||||
|  | github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= | ||||||
|  | github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= | ||||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||||
| @@ -113,8 +118,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO | |||||||
| github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= | github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= | ||||||
| github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= | github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= | ||||||
| github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= | ||||||
| github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo= | github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= | ||||||
| github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= | github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | ||||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||||
| github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= | github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= | ||||||
| @@ -188,9 +193,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | |||||||
| gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||||
| gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= | gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= | ||||||
| gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= | gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= | ||||||
|  | gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= | ||||||
|  | gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= | ||||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | ||||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | ||||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | ||||||
| gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= | gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= | ||||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | ||||||
|  | gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | ||||||
|  | gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||||
| rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= | ||||||
|   | |||||||
							
								
								
									
										31
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								i18n/en.json
									
									
									
									
									
								
							| @@ -3,6 +3,11 @@ | |||||||
|   "%d 点额度": "%d point quota", |   "%d 点额度": "%d point quota", | ||||||
|   "尚未实现": "Not yet implemented", |   "尚未实现": "Not yet implemented", | ||||||
|   "余额不足": "Insufficient balance", |   "余额不足": "Insufficient balance", | ||||||
|  |   "危险操作": "Hazardous operations", | ||||||
|  |   "输入你的账户名": "Enter your account name", | ||||||
|  |   "确认删除": "Confirm Delete", | ||||||
|  |   "确认绑定": "Confirm Binding", | ||||||
|  |   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", | ||||||
|   "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", |   "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", | ||||||
|   "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", |   "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", | ||||||
|   "测试已在运行中": "Test is already running", |   "测试已在运行中": "Test is already running", | ||||||
| @@ -34,8 +39,8 @@ | |||||||
|   "兑换码个数必须大于0": "The number of redemption codes must be greater than 0", |   "兑换码个数必须大于0": "The number of redemption codes must be greater than 0", | ||||||
|   "一次兑换码批量生成的个数不能大于 100": "The number of redemption codes generated in a batch cannot be greater than 100", |   "一次兑换码批量生成的个数不能大于 100": "The number of redemption codes generated in a batch cannot be greater than 100", | ||||||
|   "通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)", |   "通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)", | ||||||
|   "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。": "The current group load is saturated, please try again later, or upgrade your account to improve service quality.", |   "当前分组上游负载已饱和,请稍后再试": "The current group load is saturated, please try again later", | ||||||
|   "令牌名称长度必须在1-20之间": "The length of the token name must be between 1-20", |   "令牌名称过长": "Token name is too long", | ||||||
|   "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.", |   "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.", | ||||||
|   "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota", |   "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota", | ||||||
|   "管理员关闭了密码登录": "The administrator has turned off password login", |   "管理员关闭了密码登录": "The administrator has turned off password login", | ||||||
| @@ -224,7 +229,7 @@ | |||||||
|   "已是最新版本": "Is the latest version", |   "已是最新版本": "Is the latest version", | ||||||
|   "检查更新": "Check for updates", |   "检查更新": "Check for updates", | ||||||
|   "公告": "Announcement", |   "公告": "Announcement", | ||||||
|   "在此输入新的公告内容": "Enter new announcement content here", |   "在此输入新的公告内容,支持 Markdown & HTML 代码": "Enter the new announcement content here, supports Markdown & HTML code", | ||||||
|   "保存公告": "Save Announcement", |   "保存公告": "Save Announcement", | ||||||
|   "个性化设置": "Personalization Settings", |   "个性化设置": "Personalization Settings", | ||||||
|   "系统名称": "System Name", |   "系统名称": "System Name", | ||||||
| @@ -427,7 +432,7 @@ | |||||||
|   "一分钟后过期": "Expires after one minute", |   "一分钟后过期": "Expires after one minute", | ||||||
|   "创建新的令牌": "Create New Token", |   "创建新的令牌": "Create New Token", | ||||||
|   "注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.", |   "注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.", | ||||||
|   "设置为无限额度": "Set to unlimited quota", |   "设为无限额度": "Set to unlimited quota", | ||||||
|   "更新令牌信息": "Update Token Information", |   "更新令牌信息": "Update Token Information", | ||||||
|   "请输入充值码!": "Please enter the recharge code!", |   "请输入充值码!": "Please enter the recharge code!", | ||||||
|   "请输入名称": "Please enter a name", |   "请输入名称": "Please enter a name", | ||||||
| @@ -493,6 +498,7 @@ | |||||||
|   "参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)", |   "参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)", | ||||||
|   "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", |   "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", | ||||||
|   "取消无限额度": "Cancel unlimited quota", |   "取消无限额度": "Cancel unlimited quota", | ||||||
|  |   "取消": "Cancel", | ||||||
|   "请输入新的剩余额度": "Please enter the new remaining quota", |   "请输入新的剩余额度": "Please enter the new remaining quota", | ||||||
|   "请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code", |   "请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code", | ||||||
|   "请输入用户名": "Please enter username", |   "请输入用户名": "Please enter username", | ||||||
| @@ -503,5 +509,20 @@ | |||||||
|   "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", |   "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", | ||||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", |   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||||
|   "Homepage URL 填": "Fill in the Homepage URL", |   "Homepage URL 填": "Fill in the Homepage URL", | ||||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL" |   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||||
|  |   "请为通道命名": "Please name the channel", | ||||||
|  |   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||||
|  |   "模型重定向": "Model redirection", | ||||||
|  |   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||||
|  |   "注意,": "Note that, ", | ||||||
|  |   ",图片演示。": "related image demo.", | ||||||
|  |   "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!", | ||||||
|  |   "代理": "Proxy", | ||||||
|  |   "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com", | ||||||
|  |   "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?", | ||||||
|  |   "按照如下格式输入:": "Enter in the following format:", | ||||||
|  |   "模型版本": "Model version", | ||||||
|  |   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", | ||||||
|  |   "点击查看": "click to view", | ||||||
|  |   "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										37
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								main.go
									
									
									
									
									
								
							| @@ -2,6 +2,7 @@ package main | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"embed" | 	"embed" | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-contrib/sessions/cookie" | 	"github.com/gin-contrib/sessions/cookie" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -21,11 +22,14 @@ var buildFS embed.FS | |||||||
| var indexPage []byte | var indexPage []byte | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| 	common.SetupGinLog() | 	common.SetupLogger() | ||||||
| 	common.SysLog("One API " + common.Version + " started") | 	common.SysLog("One API " + common.Version + " started") | ||||||
| 	if os.Getenv("GIN_MODE") != "debug" { | 	if os.Getenv("GIN_MODE") != "debug" { | ||||||
| 		gin.SetMode(gin.ReleaseMode) | 		gin.SetMode(gin.ReleaseMode) | ||||||
| 	} | 	} | ||||||
|  | 	if common.DebugEnabled { | ||||||
|  | 		common.SysLog("running in debug mode") | ||||||
|  | 	} | ||||||
| 	// Initialize SQL Database | 	// Initialize SQL Database | ||||||
| 	err := model.InitDB() | 	err := model.InitDB() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -47,17 +51,17 @@ func main() { | |||||||
| 	// Initialize options | 	// Initialize options | ||||||
| 	model.InitOptionMap() | 	model.InitOptionMap() | ||||||
| 	if common.RedisEnabled { | 	if common.RedisEnabled { | ||||||
|  | 		// for compatibility with old versions | ||||||
|  | 		common.MemoryCacheEnabled = true | ||||||
|  | 	} | ||||||
|  | 	if common.MemoryCacheEnabled { | ||||||
|  | 		common.SysLog("memory cache enabled") | ||||||
|  | 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) | ||||||
| 		model.InitChannelCache() | 		model.InitChannelCache() | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("SYNC_FREQUENCY") != "" { | 	if common.MemoryCacheEnabled { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY")) | 		go model.SyncOptions(common.SyncFrequency) | ||||||
| 		if err != nil { | 		go model.SyncChannelCache(common.SyncFrequency) | ||||||
| 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		go model.SyncOptions(frequency) |  | ||||||
| 		if common.RedisEnabled { |  | ||||||
| 			go model.SyncChannelCache(frequency) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||||
| @@ -73,13 +77,20 @@ func main() { | |||||||
| 		} | 		} | ||||||
| 		go controller.AutomaticallyTestChannels(frequency) | 		go controller.AutomaticallyTestChannels(frequency) | ||||||
| 	} | 	} | ||||||
|  | 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | ||||||
|  | 		common.BatchUpdateEnabled = true | ||||||
|  | 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | ||||||
|  | 		model.InitBatchUpdater() | ||||||
|  | 	} | ||||||
|  | 	controller.InitTokenEncoders() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
| 	server := gin.Default() | 	server := gin.New() | ||||||
|  | 	server.Use(gin.Recovery()) | ||||||
| 	// This will cause SSE not to work!!! | 	// This will cause SSE not to work!!! | ||||||
| 	//server.Use(gzip.Gzip(gzip.DefaultCompression)) | 	//server.Use(gzip.Gzip(gzip.DefaultCompression)) | ||||||
| 	server.Use(middleware.CORS()) | 	server.Use(middleware.RequestId()) | ||||||
|  | 	middleware.SetUpLogger(server) | ||||||
| 	// Initialize session store | 	// Initialize session store | ||||||
| 	store := cookie.NewStore([]byte(common.SessionSecret)) | 	store := cookie.NewStore([]byte(common.SessionSecret)) | ||||||
| 	server.Use(sessions.Sessions("session", store)) | 	server.Use(sessions.Sessions("session", store)) | ||||||
|   | |||||||
| @@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 		key = parts[0] | 		key = parts[0] | ||||||
| 		token, err := model.ValidateUserToken(key) | 		token, err := model.ValidateUserToken(key) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			c.JSON(http.StatusUnauthorized, gin.H{ | 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||||
| 				"error": gin.H{ |  | ||||||
| 					"message": err.Error(), |  | ||||||
| 					"type":    "one_api_error", |  | ||||||
| 				}, |  | ||||||
| 			}) |  | ||||||
| 			c.Abort() |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if !model.CacheIsUserEnabled(token.UserId) { | 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | ||||||
| 			c.JSON(http.StatusForbidden, gin.H{ | 		if err != nil { | ||||||
| 				"error": gin.H{ | 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
| 					"message": "用户已被封禁", | 			return | ||||||
| 					"type":    "one_api_error", | 		} | ||||||
| 				}, | 		if !userEnabled { | ||||||
| 			}) | 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||||
| 			c.Abort() |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		c.Set("id", token.UserId) | 		c.Set("id", token.UserId) | ||||||
| @@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 			if model.IsAdmin(token.UserId) { | 			if model.IsAdmin(token.UserId) { | ||||||
| 				c.Set("channelId", parts[1]) | 				c.Set("channelId", parts[1]) | ||||||
| 			} else { | 			} else { | ||||||
| 				c.JSON(http.StatusForbidden, gin.H{ | 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "普通用户不支持指定渠道", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -25,48 +25,27 @@ func Distribute() func(c *gin.Context) { | |||||||
| 		if ok { | 		if ok { | ||||||
| 			id, err := strconv.Atoi(channelId.(string)) | 			id, err := strconv.Atoi(channelId.(string)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的渠道 ID", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			channel, err = model.GetChannelById(id, true) | 			channel, err = model.GetChannelById(id, true) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的渠道 ID", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if channel.Status != common.ChannelStatusEnabled { | 			if channel.Status != common.ChannelStatusEnabled { | ||||||
| 				c.JSON(http.StatusForbidden, gin.H{ | 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "该渠道已被禁用", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			// Select a channel for the user | 			// Select a channel for the user | ||||||
| 			var modelRequest ModelRequest | 			var modelRequest ModelRequest | ||||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | 			var err error | ||||||
|  | 			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||||
|  | 				err = common.UnmarshalBodyReusable(c, &modelRequest) | ||||||
|  | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的请求", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||||
| @@ -84,31 +63,35 @@ func Distribute() func(c *gin.Context) { | |||||||
| 					modelRequest.Model = "dall-e" | 					modelRequest.Model = "dall-e" | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||||
|  | 				if modelRequest.Model == "" { | ||||||
|  | 					modelRequest.Model = "whisper-1" | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				message := "无可用渠道" | 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||||
| 				if channel != nil { | 				if channel != nil { | ||||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||||
| 					message = "数据库一致性已被破坏,请联系管理员" | 					message = "数据库一致性已被破坏,请联系管理员" | ||||||
| 				} | 				} | ||||||
| 				c.JSON(http.StatusServiceUnavailable, gin.H{ | 				abortWithMessage(c, http.StatusServiceUnavailable, message) | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": message, |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		c.Set("channel", channel.Type) | 		c.Set("channel", channel.Type) | ||||||
| 		c.Set("channel_id", channel.Id) | 		c.Set("channel_id", channel.Id) | ||||||
| 		c.Set("channel_name", channel.Name) | 		c.Set("channel_name", channel.Name) | ||||||
| 		c.Set("model_mapping", channel.ModelMapping) | 		c.Set("model_mapping", channel.GetModelMapping()) | ||||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||||
| 		c.Set("base_url", channel.BaseURL) | 		c.Set("base_url", channel.GetBaseURL()) | ||||||
| 		if channel.Type == common.ChannelTypeAzure { | 		switch channel.Type { | ||||||
|  | 		case common.ChannelTypeAzure: | ||||||
| 			c.Set("api_version", channel.Other) | 			c.Set("api_version", channel.Other) | ||||||
|  | 		case common.ChannelTypeXunfei: | ||||||
|  | 			c.Set("api_version", channel.Other) | ||||||
|  | 		case common.ChannelTypeAIProxyLibrary: | ||||||
|  | 			c.Set("library_id", channel.Other) | ||||||
| 		} | 		} | ||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func SetUpLogger(server *gin.Engine) { | ||||||
|  | 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | ||||||
|  | 		var requestID string | ||||||
|  | 		if param.Keys != nil { | ||||||
|  | 			requestID = param.Keys[common.RequestIdKey].(string) | ||||||
|  | 		} | ||||||
|  | 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | ||||||
|  | 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | ||||||
|  | 			requestID, | ||||||
|  | 			param.StatusCode, | ||||||
|  | 			param.Latency, | ||||||
|  | 			param.ClientIP, | ||||||
|  | 			param.Method, | ||||||
|  | 			param.Path, | ||||||
|  | 		) | ||||||
|  | 	})) | ||||||
|  | } | ||||||
							
								
								
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RequestId() func(c *gin.Context) { | ||||||
|  | 	return func(c *gin.Context) { | ||||||
|  | 		id := common.GetTimeString() + common.GetRandomString(8) | ||||||
|  | 		c.Set(common.RequestIdKey, id) | ||||||
|  | 		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) | ||||||
|  | 		c.Request = c.Request.WithContext(ctx) | ||||||
|  | 		c.Header(common.RequestIdKey, id) | ||||||
|  | 		c.Next() | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||||
|  | 	c.JSON(statusCode, gin.H{ | ||||||
|  | 		"error": gin.H{ | ||||||
|  | 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), | ||||||
|  | 			"type":    "one_api_error", | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	c.Abort() | ||||||
|  | 	common.LogError(c.Request.Context(), message) | ||||||
|  | } | ||||||
| @@ -10,15 +10,18 @@ type Ability struct { | |||||||
| 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"` | 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"` | ||||||
| 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | ||||||
| 	Enabled   bool   `json:"enabled"` | 	Enabled   bool   `json:"enabled"` | ||||||
|  | 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||||
| 	ability := Ability{} | 	ability := Ability{} | ||||||
| 	var err error = nil | 	var err error = nil | ||||||
|  | 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) | ||||||
|  | 	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) | ||||||
| 	if common.UsingSQLite { | 	if common.UsingSQLite { | ||||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error | 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||||
| 	} else { | 	} else { | ||||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error | 		err = channelQuery.Order("RAND()").First(&ability).Error | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -40,6 +43,7 @@ func (channel *Channel) AddAbilities() error { | |||||||
| 				Model:     model, | 				Model:     model, | ||||||
| 				ChannelId: channel.Id, | 				ChannelId: channel.Id, | ||||||
| 				Enabled:   channel.Status == common.ChannelStatusEnabled, | 				Enabled:   channel.Status == common.ChannelStatusEnabled, | ||||||
|  | 				Priority:  channel.Priority, | ||||||
| 			} | 			} | ||||||
| 			abilities = append(abilities, ability) | 			abilities = append(abilities, ability) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -6,17 +6,18 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | var ( | ||||||
| 	TokenCacheSeconds         = 60 * 60 | 	TokenCacheSeconds         = common.SyncFrequency | ||||||
| 	UserId2GroupCacheSeconds  = 60 * 60 | 	UserId2GroupCacheSeconds  = common.SyncFrequency | ||||||
| 	UserId2QuotaCacheSeconds  = 10 * 60 | 	UserId2QuotaCacheSeconds  = common.SyncFrequency | ||||||
| 	UserId2StatusCacheSeconds = 60 * 60 | 	UserId2StatusCacheSeconds = common.SyncFrequency | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func CacheGetTokenByKey(key string) (*Token, error) { | func CacheGetTokenByKey(key string) (*Token, error) { | ||||||
| @@ -35,7 +36,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second) | 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set token error: " + err.Error()) | 			common.SysError("Redis set token error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| @@ -55,7 +56,7 @@ func CacheGetUserGroup(id int) (group string, err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return "", err | ||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set user group error: " + err.Error()) | 			common.SysError("Redis set user group error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| @@ -73,7 +74,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return 0, err | 			return 0, err | ||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set user quota error: " + err.Error()) | 			common.SysError("Redis set user quota error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| @@ -91,27 +92,40 @@ func CacheUpdateUserQuota(id int) error { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) | 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheIsUserEnabled(userId int) bool { | func CacheDecreaseUserQuota(id int, quota int) error { | ||||||
|  | 	if !common.RedisEnabled { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota)) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CacheIsUserEnabled(userId int) (bool, error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return IsUserEnabled(userId) | 		return IsUserEnabled(userId) | ||||||
| 	} | 	} | ||||||
| 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) | 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) | ||||||
| 	if err != nil { | 	if err == nil { | ||||||
| 		status := common.UserStatusDisabled | 		return enabled == "1", nil | ||||||
| 		if IsUserEnabled(userId) { |  | ||||||
| 			status = common.UserStatusEnabled |  | ||||||
| 	} | 	} | ||||||
| 		enabled = fmt.Sprintf("%d", status) |  | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second) | 	userEnabled, err := IsUserEnabled(userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	enabled = "0" | ||||||
|  | 	if userEnabled { | ||||||
|  | 		enabled = "1" | ||||||
|  | 	} | ||||||
|  | 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("Redis set user enabled error: " + err.Error()) | 		common.SysError("Redis set user enabled error: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	} | 	return userEnabled, err | ||||||
| 	return enabled == "1" |  | ||||||
| } | } | ||||||
|  |  | ||||||
| var group2model2channels map[string]map[string][]*Channel | var group2model2channels map[string]map[string][]*Channel | ||||||
| @@ -146,6 +160,17 @@ func InitChannelCache() { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// sort by priority | ||||||
|  | 	for group, model2channels := range newGroup2model2channels { | ||||||
|  | 		for model, channels := range model2channels { | ||||||
|  | 			sort.Slice(channels, func(i, j int) bool { | ||||||
|  | 				return channels[i].GetPriority() > channels[j].GetPriority() | ||||||
|  | 			}) | ||||||
|  | 			newGroup2model2channels[group][model] = channels | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	channelSyncLock.Lock() | 	channelSyncLock.Lock() | ||||||
| 	group2model2channels = newGroup2model2channels | 	group2model2channels = newGroup2model2channels | ||||||
| 	channelSyncLock.Unlock() | 	channelSyncLock.Unlock() | ||||||
| @@ -161,7 +186,7 @@ func SyncChannelCache(frequency int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.MemoryCacheEnabled { | ||||||
| 		return GetRandomSatisfiedChannel(group, model) | 		return GetRandomSatisfiedChannel(group, model) | ||||||
| 	} | 	} | ||||||
| 	channelSyncLock.RLock() | 	channelSyncLock.RLock() | ||||||
| @@ -170,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | |||||||
| 	if len(channels) == 0 { | 	if len(channels) == 0 { | ||||||
| 		return nil, errors.New("channel not found") | 		return nil, errors.New("channel not found") | ||||||
| 	} | 	} | ||||||
| 	idx := rand.Intn(len(channels)) | 	endIdx := len(channels) | ||||||
|  | 	// choose by priority | ||||||
|  | 	firstChannel := channels[0] | ||||||
|  | 	if firstChannel.GetPriority() > 0 { | ||||||
|  | 		for i := range channels { | ||||||
|  | 			if channels[i].GetPriority() != firstChannel.GetPriority() { | ||||||
|  | 				endIdx = i | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	idx := rand.Intn(endIdx) | ||||||
| 	return channels[idx], nil | 	return channels[idx], nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -11,18 +11,19 @@ type Channel struct { | |||||||
| 	Key                string  `json:"key" gorm:"not null;index"` | 	Key                string  `json:"key" gorm:"not null;index"` | ||||||
| 	Status             int     `json:"status" gorm:"default:1"` | 	Status             int     `json:"status" gorm:"default:1"` | ||||||
| 	Name               string  `json:"name" gorm:"index"` | 	Name               string  `json:"name" gorm:"index"` | ||||||
| 	Weight             int     `json:"weight"` | 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||||
| 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | ||||||
| 	TestTime           int64   `json:"test_time" gorm:"bigint"` | 	TestTime           int64   `json:"test_time" gorm:"bigint"` | ||||||
| 	ResponseTime       int     `json:"response_time"` // in milliseconds | 	ResponseTime       int     `json:"response_time"` // in milliseconds | ||||||
| 	BaseURL            string  `json:"base_url" gorm:"column:base_url"` | 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | ||||||
| 	Other              string  `json:"other"` | 	Other              string  `json:"other"` | ||||||
| 	Balance            float64 `json:"balance"` // in USD | 	Balance            float64 `json:"balance"` // in USD | ||||||
| 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | ||||||
| 	Models             string  `json:"models"` | 	Models             string  `json:"models"` | ||||||
| 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||||
| 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||||
|  | 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||||
| @@ -78,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (channel *Channel) GetPriority() int64 { | ||||||
|  | 	if channel.Priority == nil { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	return *channel.Priority | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (channel *Channel) GetBaseURL() string { | ||||||
|  | 	if channel.BaseURL == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return *channel.BaseURL | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (channel *Channel) GetModelMapping() string { | ||||||
|  | 	if channel.ModelMapping == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return *channel.ModelMapping | ||||||
|  | } | ||||||
|  |  | ||||||
| func (channel *Channel) Insert() error { | func (channel *Channel) Insert() error { | ||||||
| 	var err error | 	var err error | ||||||
| 	err = DB.Create(channel).Error | 	err = DB.Create(channel).Error | ||||||
| @@ -141,6 +163,14 @@ func UpdateChannelStatusById(id int, status int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelUsedQuota(id int, quota int) { | func UpdateChannelUsedQuota(id int, quota int) { | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	updateChannelUsedQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateChannelUsedQuota(id int, quota int) { | ||||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||||
|   | |||||||
							
								
								
									
										30
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -1,13 +1,15 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Log struct { | type Log struct { | ||||||
| 	Id               int    `json:"id"` | 	Id               int    `json:"id"` | ||||||
| 	UserId           int    `json:"user_id"` | 	UserId           int    `json:"user_id" gorm:"index"` | ||||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index"` | 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index"` | ||||||
| 	Type             int    `json:"type" gorm:"index"` | 	Type             int    `json:"type" gorm:"index"` | ||||||
| 	Content          string `json:"content"` | 	Content          string `json:"content"` | ||||||
| @@ -17,6 +19,7 @@ type Log struct { | |||||||
| 	Quota            int    `json:"quota" gorm:"default:0"` | 	Quota            int    `json:"quota" gorm:"default:0"` | ||||||
| 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | ||||||
| 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | ||||||
|  | 	ChannelId        int    `json:"channel" gorm:"index"` | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -44,7 +47,8 @@ func RecordLog(userId int, logType int, content string) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||||
|  | 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||||
| 	if !common.LogConsumeEnabled { | 	if !common.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -59,14 +63,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN | |||||||
| 		TokenName:        tokenName, | 		TokenName:        tokenName, | ||||||
| 		ModelName:        modelName, | 		ModelName:        modelName, | ||||||
| 		Quota:            quota, | 		Quota:            quota, | ||||||
|  | 		ChannelId:        channelId, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to record log: " + err.Error()) | 		common.LogError(ctx, "failed to record log: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | ||||||
| 	var tx *gorm.DB | 	var tx *gorm.DB | ||||||
| 	if logType == LogTypeUnknown { | 	if logType == LogTypeUnknown { | ||||||
| 		tx = DB | 		tx = DB | ||||||
| @@ -88,6 +93,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | |||||||
| 	if endTimestamp != 0 { | 	if endTimestamp != 0 { | ||||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||||
| 	} | 	} | ||||||
|  | 	if channel != 0 { | ||||||
|  | 		tx = tx.Where("channel = ?", channel) | ||||||
|  | 	} | ||||||
| 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
| @@ -125,8 +133,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | |||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) { | func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | ||||||
| 	tx := DB.Table("logs").Select("sum(quota)") | 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -142,12 +150,15 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| 	if modelName != "" { | 	if modelName != "" { | ||||||
| 		tx = tx.Where("model_name = ?", modelName) | 		tx = tx.Where("model_name = ?", modelName) | ||||||
| 	} | 	} | ||||||
|  | 	if channel != 0 { | ||||||
|  | 		tx = tx.Where("channel = ?", channel) | ||||||
|  | 	} | ||||||
| 	tx.Where("type = ?", LogTypeConsume).Scan("a) | 	tx.Where("type = ?", LogTypeConsume).Scan("a) | ||||||
| 	return quota | 	return quota | ||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||||
| 	tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)") | 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -166,3 +177,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| 	tx.Where("type = ?", LogTypeConsume).Scan(&token) | 	tx.Where("type = ?", LogTypeConsume).Scan(&token) | ||||||
| 	return token | 	return token | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func DeleteOldLog(targetTimestamp int64) (int64, error) { | ||||||
|  | 	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||||
|  | 	return result.RowsAffected, result.Error | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,10 +2,13 @@ package model | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"gorm.io/driver/mysql" | 	"gorm.io/driver/mysql" | ||||||
|  | 	"gorm.io/driver/postgres" | ||||||
| 	"gorm.io/driver/sqlite" | 	"gorm.io/driver/sqlite" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var DB *gorm.DB | var DB *gorm.DB | ||||||
| @@ -33,34 +36,52 @@ func createRootAccountIfNeed() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func CountTable(tableName string) (num int64) { | func chooseDB() (*gorm.DB, error) { | ||||||
| 	DB.Table(tableName).Count(&num) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func InitDB() (err error) { |  | ||||||
| 	var db *gorm.DB |  | ||||||
| 	if os.Getenv("SQL_DSN") != "" { | 	if os.Getenv("SQL_DSN") != "" { | ||||||
| 		// Use MySQL | 		dsn := os.Getenv("SQL_DSN") | ||||||
| 		common.SysLog("using MySQL as database") | 		if strings.HasPrefix(dsn, "postgres://") { | ||||||
| 		db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{ | 			// Use PostgreSQL | ||||||
|  | 			common.SysLog("using PostgreSQL as database") | ||||||
|  | 			return gorm.Open(postgres.New(postgres.Config{ | ||||||
|  | 				DSN:                  dsn, | ||||||
|  | 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||||
|  | 			}), &gorm.Config{ | ||||||
| 				PrepareStmt: true, // precompile SQL | 				PrepareStmt: true, // precompile SQL | ||||||
| 			}) | 			}) | ||||||
| 	} else { | 		} | ||||||
|  | 		// Use MySQL | ||||||
|  | 		common.SysLog("using MySQL as database") | ||||||
|  | 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||||
|  | 			PrepareStmt: true, // precompile SQL | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
| 	// Use SQLite | 	// Use SQLite | ||||||
| 	common.SysLog("SQL_DSN not set, using SQLite as database") | 	common.SysLog("SQL_DSN not set, using SQLite as database") | ||||||
| 	common.UsingSQLite = true | 	common.UsingSQLite = true | ||||||
| 		db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ | 	return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ | ||||||
| 		PrepareStmt: true, // precompile SQL | 		PrepareStmt: true, // precompile SQL | ||||||
| 	}) | 	}) | ||||||
| 	} | } | ||||||
| 	common.SysLog("database connected") |  | ||||||
|  | func InitDB() (err error) { | ||||||
|  | 	db, err := chooseDB() | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
|  | 		if common.DebugEnabled { | ||||||
|  | 			db = db.Debug() | ||||||
|  | 		} | ||||||
| 		DB = db | 		DB = db | ||||||
|  | 		sqlDB, err := DB.DB() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) | ||||||
|  | 		sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) | ||||||
|  | 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) | ||||||
|  |  | ||||||
| 		if !common.IsMasterNode { | 		if !common.IsMasterNode { | ||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
| 		err := db.AutoMigrate(&Channel{}) | 		err = db.AutoMigrate(&Channel{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -39,6 +39,8 @@ func InitOptionMap() { | |||||||
| 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | ||||||
| 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) | 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) | ||||||
| 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) | 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) | ||||||
|  | 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) | ||||||
|  | 	common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") | ||||||
| 	common.OptionMap["SMTPServer"] = "" | 	common.OptionMap["SMTPServer"] = "" | ||||||
| 	common.OptionMap["SMTPFrom"] = "" | 	common.OptionMap["SMTPFrom"] = "" | ||||||
| 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) | 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) | ||||||
| @@ -141,6 +143,8 @@ func updateOptionMap(key string, value string) (err error) { | |||||||
| 			common.TurnstileCheckEnabled = boolValue | 			common.TurnstileCheckEnabled = boolValue | ||||||
| 		case "RegisterEnabled": | 		case "RegisterEnabled": | ||||||
| 			common.RegisterEnabled = boolValue | 			common.RegisterEnabled = boolValue | ||||||
|  | 		case "EmailDomainRestrictionEnabled": | ||||||
|  | 			common.EmailDomainRestrictionEnabled = boolValue | ||||||
| 		case "AutomaticDisableChannelEnabled": | 		case "AutomaticDisableChannelEnabled": | ||||||
| 			common.AutomaticDisableChannelEnabled = boolValue | 			common.AutomaticDisableChannelEnabled = boolValue | ||||||
| 		case "ApproximateTokenEnabled": | 		case "ApproximateTokenEnabled": | ||||||
| @@ -154,6 +158,8 @@ func updateOptionMap(key string, value string) (err error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	switch key { | 	switch key { | ||||||
|  | 	case "EmailDomainWhitelist": | ||||||
|  | 		common.EmailDomainWhitelist = strings.Split(value, ",") | ||||||
| 	case "SMTPServer": | 	case "SMTPServer": | ||||||
| 		common.SMTPServer = value | 		common.SMTPServer = value | ||||||
| 	case "SMTPPort": | 	case "SMTPPort": | ||||||
|   | |||||||
| @@ -51,20 +51,21 @@ func Redeem(key string, userId int) (quota int, err error) { | |||||||
| 	redemption := &Redemption{} | 	redemption := &Redemption{} | ||||||
|  |  | ||||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||||
| 		err := DB.Where("`key` = ?", key).First(redemption).Error | 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errors.New("无效的兑换码") | 			return errors.New("无效的兑换码") | ||||||
| 		} | 		} | ||||||
| 		if redemption.Status != common.RedemptionCodeStatusEnabled { | 		if redemption.Status != common.RedemptionCodeStatusEnabled { | ||||||
| 			return errors.New("该兑换码已被使用") | 			return errors.New("该兑换码已被使用") | ||||||
| 		} | 		} | ||||||
| 		err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		redemption.RedeemedTime = common.GetTimestamp() | 		redemption.RedeemedTime = common.GetTimestamp() | ||||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | 		redemption.Status = common.RedemptionCodeStatusUsed | ||||||
| 		return redemption.SelectUpdate() | 		err = tx.Save(redemption).Error | ||||||
|  | 		return err | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, errors.New("兑换失败," + err.Error()) | 		return 0, errors.New("兑换失败," + err.Error()) | ||||||
|   | |||||||
| @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { | |||||||
| 	} | 	} | ||||||
| 	token, err = CacheGetTokenByKey(key) | 	token, err = CacheGetTokenByKey(key) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
|  | 		if token.Status == common.TokenStatusExhausted { | ||||||
|  | 			return nil, errors.New("该令牌额度已用尽") | ||||||
|  | 		} else if token.Status == common.TokenStatusExpired { | ||||||
|  | 			return nil, errors.New("该令牌已过期") | ||||||
|  | 		} | ||||||
| 		if token.Status != common.TokenStatusEnabled { | 		if token.Status != common.TokenStatusEnabled { | ||||||
| 			return nil, errors.New("该令牌状态不可用") | 			return nil, errors.New("该令牌状态不可用") | ||||||
| 		} | 		} | ||||||
| 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | ||||||
|  | 			if !common.RedisEnabled { | ||||||
| 				token.Status = common.TokenStatusExpired | 				token.Status = common.TokenStatusExpired | ||||||
| 				err := token.SelectUpdate() | 				err := token.SelectUpdate() | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("failed to update token status" + err.Error()) | 					common.SysError("failed to update token status" + err.Error()) | ||||||
| 				} | 				} | ||||||
|  | 			} | ||||||
| 			return nil, errors.New("该令牌已过期") | 			return nil, errors.New("该令牌已过期") | ||||||
| 		} | 		} | ||||||
| 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||||
|  | 			if !common.RedisEnabled { | ||||||
|  | 				// in this case, we can make sure the token is exhausted | ||||||
| 				token.Status = common.TokenStatusExhausted | 				token.Status = common.TokenStatusExhausted | ||||||
| 				err := token.SelectUpdate() | 				err := token.SelectUpdate() | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("failed to update token status" + err.Error()) | 					common.SysError("failed to update token status" + err.Error()) | ||||||
| 				} | 				} | ||||||
|  | 			} | ||||||
| 			return nil, errors.New("该令牌额度已用尽") | 			return nil, errors.New("该令牌额度已用尽") | ||||||
| 		} | 		} | ||||||
| 		go func() { |  | ||||||
| 			token.AccessedTime = common.GetTimestamp() |  | ||||||
| 			err := token.SelectUpdate() |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("failed to update token" + err.Error()) |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 		return token, nil | 		return token, nil | ||||||
| 	} | 	} | ||||||
| 	return nil, errors.New("无效的令牌") | 	return nil, errors.New("无效的令牌") | ||||||
| @@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return increaseTokenQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func increaseTokenQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||||
| 			"used_quota":    gorm.Expr("used_quota - ?", quota), | 			"used_quota":    gorm.Expr("used_quota - ?", quota), | ||||||
|  | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
| @@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return decreaseTokenQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func decreaseTokenQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
|  | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
|   | |||||||
| @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { | |||||||
| 	return user.Role >= common.RoleAdminUser | 	return user.Role >= common.RoleAdminUser | ||||||
| } | } | ||||||
|  |  | ||||||
| func IsUserEnabled(userId int) bool { | func IsUserEnabled(userId int) (bool, error) { | ||||||
| 	if userId == 0 { | 	if userId == 0 { | ||||||
| 		return false | 		return false, errors.New("user id is empty") | ||||||
| 	} | 	} | ||||||
| 	var user User | 	var user User | ||||||
| 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error | 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("no such user " + err.Error()) | 		return false, err | ||||||
| 		return false |  | ||||||
| 	} | 	} | ||||||
| 	return user.Status == common.UserStatusEnabled | 	return user.Status == common.UserStatusEnabled, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func ValidateAccessToken(token string) (user *User) { | func ValidateAccessToken(token string) (user *User) { | ||||||
| @@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return increaseUserQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func increaseUserQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return decreaseUserQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func decreaseUserQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
| 			"request_count": gorm.Expr("request_count + ?", 1), | 			"request_count": gorm.Expr("request_count + ?", count), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
							
								
								
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | |||||||
|  | package model | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	BatchUpdateTypeUserQuota = iota | ||||||
|  | 	BatchUpdateTypeTokenQuota | ||||||
|  | 	BatchUpdateTypeUsedQuotaAndRequestCount | ||||||
|  | 	BatchUpdateTypeChannelUsedQuota | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var batchUpdateStores []map[int]int | ||||||
|  | var batchUpdateLocks []sync.Mutex | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
|  | 		batchUpdateStores = append(batchUpdateStores, make(map[int]int)) | ||||||
|  | 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func InitBatchUpdater() { | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) | ||||||
|  | 			batchUpdate() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func addNewRecord(type_ int, id int, value int) { | ||||||
|  | 	batchUpdateLocks[type_].Lock() | ||||||
|  | 	defer batchUpdateLocks[type_].Unlock() | ||||||
|  | 	if _, ok := batchUpdateStores[type_][id]; !ok { | ||||||
|  | 		batchUpdateStores[type_][id] = value | ||||||
|  | 	} else { | ||||||
|  | 		batchUpdateStores[type_][id] += value | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func batchUpdate() { | ||||||
|  | 	common.SysLog("batch update started") | ||||||
|  | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
|  | 		batchUpdateLocks[i].Lock() | ||||||
|  | 		store := batchUpdateStores[i] | ||||||
|  | 		batchUpdateStores[i] = make(map[int]int) | ||||||
|  | 		batchUpdateLocks[i].Unlock() | ||||||
|  |  | ||||||
|  | 		for key, value := range store { | ||||||
|  | 			switch i { | ||||||
|  | 			case BatchUpdateTypeUserQuota: | ||||||
|  | 				err := increaseUserQuota(key, value) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to batch update user quota: " + err.Error()) | ||||||
|  | 				} | ||||||
|  | 			case BatchUpdateTypeTokenQuota: | ||||||
|  | 				err := increaseTokenQuota(key, value) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||||
|  | 				} | ||||||
|  | 			case BatchUpdateTypeUsedQuotaAndRequestCount: | ||||||
|  | 				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect | ||||||
|  | 			case BatchUpdateTypeChannelUsedQuota: | ||||||
|  | 				updateChannelUsedQuota(key, value) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	common.SysLog("batch update finished") | ||||||
|  | } | ||||||
| @@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | ||||||
| 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | ||||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | ||||||
|  | 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) | ||||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | ||||||
| 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | ||||||
| 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | ||||||
| @@ -97,6 +98,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 		} | 		} | ||||||
| 		logRoute := apiRouter.Group("/log") | 		logRoute := apiRouter.Group("/log") | ||||||
| 		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) | 		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) | ||||||
|  | 		logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs) | ||||||
| 		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) | 		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) | ||||||
| 		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) | 		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) | ||||||
| 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) | 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) | ||||||
|   | |||||||
| @@ -8,11 +8,12 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetRelayRouter(router *gin.Engine) { | func SetRelayRouter(router *gin.Engine) { | ||||||
|  | 	router.Use(middleware.CORS()) | ||||||
| 	// https://platform.openai.com/docs/api-reference/introduction | 	// https://platform.openai.com/docs/api-reference/introduction | ||||||
| 	modelsRouter := router.Group("/v1/models") | 	modelsRouter := router.Group("/v1/models") | ||||||
| 	modelsRouter.Use(middleware.TokenAuth()) | 	modelsRouter.Use(middleware.TokenAuth()) | ||||||
| 	{ | 	{ | ||||||
| 		modelsRouter.GET("/", controller.ListModels) | 		modelsRouter.GET("", controller.ListModels) | ||||||
| 		modelsRouter.GET("/:model", controller.RetrieveModel) | 		modelsRouter.GET("/:model", controller.RetrieveModel) | ||||||
| 	} | 	} | ||||||
| 	relayV1Router := router.Group("/v1") | 	relayV1Router := router.Group("/v1") | ||||||
| @@ -26,8 +27,8 @@ func SetRelayRouter(router *gin.Engine) { | |||||||
| 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented) | 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented) | ||||||
| 		relayV1Router.POST("/embeddings", controller.Relay) | 		relayV1Router.POST("/embeddings", controller.Relay) | ||||||
| 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | ||||||
| 		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) | 		relayV1Router.POST("/audio/transcriptions", controller.Relay) | ||||||
| 		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) | 		relayV1Router.POST("/audio/translations", controller.Relay) | ||||||
| 		relayV1Router.GET("/files", controller.RelayNotImplemented) | 		relayV1Router.GET("/files", controller.RelayNotImplemented) | ||||||
| 		relayV1Router.POST("/files", controller.RelayNotImplemented) | 		relayV1Router.POST("/files", controller.RelayNotImplemented) | ||||||
| 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) | 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) | ||||||
|   | |||||||
| @@ -18,7 +18,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { | |||||||
| 	router.Use(middleware.Cache()) | 	router.Use(middleware.Cache()) | ||||||
| 	router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build"))) | 	router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build"))) | ||||||
| 	router.NoRoute(func(c *gin.Context) { | 	router.NoRoute(func(c *gin.Context) { | ||||||
| 		if strings.HasPrefix(c.Request.RequestURI, "/v1") { | 		if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { | ||||||
| 			controller.RelayNotFound(c) | 			controller.RelayNotFound(c) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link } from 'react-router-dom'; | ||||||
| import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; | ||||||
|  |  | ||||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||||
| import { renderGroup, renderNumber } from '../helpers/render'; | import { renderGroup, renderNumber } from '../helpers/render'; | ||||||
| @@ -24,7 +24,7 @@ function renderType(type) { | |||||||
|     } |     } | ||||||
|     type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; |     type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; | ||||||
|   } |   } | ||||||
|   return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>; |   return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>; | ||||||
| } | } | ||||||
|  |  | ||||||
| function renderBalance(type, balance) { | function renderBalance(type, balance) { | ||||||
| @@ -96,7 +96,7 @@ const ChannelsTable = () => { | |||||||
|       }); |       }); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|   const manageChannel = async (id, action, idx) => { |   const manageChannel = async (id, action, idx, priority) => { | ||||||
|     let data = { id }; |     let data = { id }; | ||||||
|     let res; |     let res; | ||||||
|     switch (action) { |     switch (action) { | ||||||
| @@ -111,6 +111,13 @@ const ChannelsTable = () => { | |||||||
|         data.status = 2; |         data.status = 2; | ||||||
|         res = await API.put('/api/channel/', data); |         res = await API.put('/api/channel/', data); | ||||||
|         break; |         break; | ||||||
|  |       case 'priority': | ||||||
|  |         if (priority === '') { | ||||||
|  |           return; | ||||||
|  |         } | ||||||
|  |         data.priority = parseInt(priority); | ||||||
|  |         res = await API.put('/api/channel/', data); | ||||||
|  |         break; | ||||||
|     } |     } | ||||||
|     const { success, message } = res.data; |     const { success, message } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
| @@ -195,6 +202,7 @@ const ChannelsTable = () => { | |||||||
|       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); |       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|  |       showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
| @@ -334,6 +342,14 @@ const ChannelsTable = () => { | |||||||
|             > |             > | ||||||
|               余额 |               余额 | ||||||
|             </Table.HeaderCell> |             </Table.HeaderCell> | ||||||
|  |             <Table.HeaderCell | ||||||
|  |                 style={{ cursor: 'pointer' }} | ||||||
|  |                 onClick={() => { | ||||||
|  |                   sortChannel('priority'); | ||||||
|  |                 }} | ||||||
|  |             > | ||||||
|  |               优先级 | ||||||
|  |             </Table.HeaderCell> | ||||||
|             <Table.HeaderCell>操作</Table.HeaderCell> |             <Table.HeaderCell>操作</Table.HeaderCell> | ||||||
|           </Table.Row> |           </Table.Row> | ||||||
|         </Table.Header> |         </Table.Header> | ||||||
| @@ -363,9 +379,28 @@ const ChannelsTable = () => { | |||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
|                   <Table.Cell> |                   <Table.Cell> | ||||||
|                     <Popup |                     <Popup | ||||||
|                       content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'} |                       trigger={<span onClick={() => { | ||||||
|                       key={channel.id} |                         updateChannelBalance(channel.id, channel.name, idx); | ||||||
|                       trigger={renderBalance(channel.type, channel.balance)} |                       }} style={{ cursor: 'pointer' }}> | ||||||
|  |                       {renderBalance(channel.type, channel.balance)} | ||||||
|  |                     </span>} | ||||||
|  |                       content='点击更新' | ||||||
|  |                       basic | ||||||
|  |                     /> | ||||||
|  |                   </Table.Cell> | ||||||
|  |                   <Table.Cell> | ||||||
|  |                     <Popup | ||||||
|  |                         trigger={<Input type="number"  defaultValue={channel.priority} onBlur={(event) => { | ||||||
|  |                           manageChannel( | ||||||
|  |                               channel.id, | ||||||
|  |                               'priority', | ||||||
|  |                               idx, | ||||||
|  |                               event.target.value, | ||||||
|  |                           ); | ||||||
|  |                         }}> | ||||||
|  |                           <input style={{maxWidth:'60px'}} /> | ||||||
|  |                         </Input>} | ||||||
|  |                         content='渠道选择优先级,越高越优先' | ||||||
|                         basic |                         basic | ||||||
|                     /> |                     /> | ||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
| @@ -380,16 +415,16 @@ const ChannelsTable = () => { | |||||||
|                       > |                       > | ||||||
|                         测试 |                         测试 | ||||||
|                       </Button> |                       </Button> | ||||||
|                       <Button |                       {/*<Button*/} | ||||||
|                         size={'small'} |                       {/*  size={'small'}*/} | ||||||
|                         positive |                       {/*  positive*/} | ||||||
|                         loading={updatingBalance} |                       {/*  loading={updatingBalance}*/} | ||||||
|                         onClick={() => { |                       {/*  onClick={() => {*/} | ||||||
|                           updateChannelBalance(channel.id, channel.name, idx); |                       {/*    updateChannelBalance(channel.id, channel.name, idx);*/} | ||||||
|                         }} |                       {/*  }}*/} | ||||||
|                       > |                       {/*>*/} | ||||||
|                         更新余额 |                       {/*  更新余额*/} | ||||||
|                       </Button> |                       {/*</Button>*/} | ||||||
|                       <Popup |                       <Popup | ||||||
|                         trigger={ |                         trigger={ | ||||||
|                           <Button size='small' negative> |                           <Button size='small' negative> | ||||||
| @@ -437,7 +472,7 @@ const ChannelsTable = () => { | |||||||
|  |  | ||||||
|         <Table.Footer> |         <Table.Footer> | ||||||
|           <Table.Row> |           <Table.Row> | ||||||
|             <Table.HeaderCell colSpan='8'> |             <Table.HeaderCell colSpan='9'> | ||||||
|               <Button size='small' as={Link} to='/channel/add' loading={loading}> |               <Button size='small' as={Link} to='/channel/add' loading={loading}> | ||||||
|                 添加新的渠道 |                 添加新的渠道 | ||||||
|               </Button> |               </Button> | ||||||
|   | |||||||
| @@ -13,8 +13,8 @@ const GitHubOAuth = () => { | |||||||
|  |  | ||||||
|   let navigate = useNavigate(); |   let navigate = useNavigate(); | ||||||
|  |  | ||||||
|   const sendCode = async (code, count) => { |   const sendCode = async (code, state, count) => { | ||||||
|     const res = await API.get(`/api/oauth/github?code=${code}`); |     const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); | ||||||
|     const { success, message, data } = res.data; |     const { success, message, data } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       if (message === 'bind') { |       if (message === 'bind') { | ||||||
| @@ -36,13 +36,14 @@ const GitHubOAuth = () => { | |||||||
|       count++; |       count++; | ||||||
|       setPrompt(`出现错误,第 ${count} 次重试中...`); |       setPrompt(`出现错误,第 ${count} 次重试中...`); | ||||||
|       await new Promise((resolve) => setTimeout(resolve, count * 2000)); |       await new Promise((resolve) => setTimeout(resolve, count * 2000)); | ||||||
|       await sendCode(code, count); |       await sendCode(code, state, count); | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let code = searchParams.get('code'); |     let code = searchParams.get('code'); | ||||||
|     sendCode(code, 0).then(); |     let state = searchParams.get('state'); | ||||||
|  |     sendCode(code, state, 0).then(); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|   | |||||||
| @@ -1,36 +1,26 @@ | |||||||
| import React, { useContext, useEffect, useState } from 'react'; | import React, { useContext, useEffect, useState } from 'react'; | ||||||
| import { | import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||||
|   Button, |  | ||||||
|   Divider, |  | ||||||
|   Form, |  | ||||||
|   Grid, |  | ||||||
|   Header, |  | ||||||
|   Image, |  | ||||||
|   Message, |  | ||||||
|   Modal, |  | ||||||
|   Segment, |  | ||||||
| } from 'semantic-ui-react'; |  | ||||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||||
| import { UserContext } from '../context/User'; | import { UserContext } from '../context/User'; | ||||||
| import { API, getLogo, showError, showSuccess, showInfo } from '../helpers'; | import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; | ||||||
|  | import { onGitHubOAuthClicked } from './utils'; | ||||||
|  |  | ||||||
| const LoginForm = () => { | const LoginForm = () => { | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
|     username: '', |     username: '', | ||||||
|     password: '', |     password: '', | ||||||
|     wechat_verification_code: '', |     wechat_verification_code: '' | ||||||
|   }); |   }); | ||||||
|   const [searchParams, setSearchParams] = useSearchParams(); |   const [searchParams, setSearchParams] = useSearchParams(); | ||||||
|   const [submitted, setSubmitted] = useState(false); |   const [submitted, setSubmitted] = useState(false); | ||||||
|   const { username, password } = inputs; |   const { username, password } = inputs; | ||||||
|   const [userState, userDispatch] = useContext(UserContext); |   const [userState, userDispatch] = useContext(UserContext); | ||||||
|   let navigate = useNavigate(); |   let navigate = useNavigate(); | ||||||
|  |  | ||||||
|   const [status, setStatus] = useState({}); |   const [status, setStatus] = useState({}); | ||||||
|   const logo = getLogo(); |   const logo = getLogo(); | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     if (searchParams.get("expired")) { |     if (searchParams.get('expired')) { | ||||||
|       showError('未登录或登录已过期,请重新登录!'); |       showError('未登录或登录已过期,请重新登录!'); | ||||||
|     } |     } | ||||||
|     let status = localStorage.getItem('status'); |     let status = localStorage.getItem('status'); | ||||||
| @@ -42,12 +32,6 @@ const LoginForm = () => { | |||||||
|  |  | ||||||
|   const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); |   const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); | ||||||
|  |  | ||||||
|   const onGitHubOAuthClicked = () => { |  | ||||||
|     window.open( |  | ||||||
|       `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` |  | ||||||
|     ); |  | ||||||
|   }; |  | ||||||
|  |  | ||||||
|   const onWeChatLoginClicked = () => { |   const onWeChatLoginClicked = () => { | ||||||
|     setShowWeChatLoginModal(true); |     setShowWeChatLoginModal(true); | ||||||
|   }; |   }; | ||||||
| @@ -78,14 +62,20 @@ const LoginForm = () => { | |||||||
|     if (username && password) { |     if (username && password) { | ||||||
|       const res = await API.post(`/api/user/login`, { |       const res = await API.post(`/api/user/login`, { | ||||||
|         username, |         username, | ||||||
|         password, |         password | ||||||
|       }); |       }); | ||||||
|       const { success, message, data } = res.data; |       const { success, message, data } = res.data; | ||||||
|       if (success) { |       if (success) { | ||||||
|         userDispatch({ type: 'login', payload: data }); |         userDispatch({ type: 'login', payload: data }); | ||||||
|         localStorage.setItem('user', JSON.stringify(data)); |         localStorage.setItem('user', JSON.stringify(data)); | ||||||
|         navigate('/'); |         if (username === 'root' && password === '123456') { | ||||||
|  |           navigate('/user/edit'); | ||||||
|           showSuccess('登录成功!'); |           showSuccess('登录成功!'); | ||||||
|  |           showWarning('请立刻修改默认密码!'); | ||||||
|  |         } else { | ||||||
|  |           navigate('/token'); | ||||||
|  |           showSuccess('登录成功!'); | ||||||
|  |         } | ||||||
|       } else { |       } else { | ||||||
|         showError(message); |         showError(message); | ||||||
|       } |       } | ||||||
| @@ -93,44 +83,44 @@ const LoginForm = () => { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Grid textAlign="center" style={{ marginTop: '48px' }}> |     <Grid textAlign='center' style={{ marginTop: '48px' }}> | ||||||
|       <Grid.Column style={{ maxWidth: 450 }}> |       <Grid.Column style={{ maxWidth: 450 }}> | ||||||
|         <Header as="h2" color="" textAlign="center"> |         <Header as='h2' color='' textAlign='center'> | ||||||
|           <Image src={logo} /> 用户登录 |           <Image src={logo} /> 用户登录 | ||||||
|         </Header> |         </Header> | ||||||
|         <Form size="large"> |         <Form size='large'> | ||||||
|           <Segment> |           <Segment> | ||||||
|             <Form.Input |             <Form.Input | ||||||
|               fluid |               fluid | ||||||
|               icon="user" |               icon='user' | ||||||
|               iconPosition="left" |               iconPosition='left' | ||||||
|               placeholder="用户名" |               placeholder='用户名' | ||||||
|               name="username" |               name='username' | ||||||
|               value={username} |               value={username} | ||||||
|               onChange={handleChange} |               onChange={handleChange} | ||||||
|             /> |             /> | ||||||
|             <Form.Input |             <Form.Input | ||||||
|               fluid |               fluid | ||||||
|               icon="lock" |               icon='lock' | ||||||
|               iconPosition="left" |               iconPosition='left' | ||||||
|               placeholder="密码" |               placeholder='密码' | ||||||
|               name="password" |               name='password' | ||||||
|               type="password" |               type='password' | ||||||
|               value={password} |               value={password} | ||||||
|               onChange={handleChange} |               onChange={handleChange} | ||||||
|             /> |             /> | ||||||
|             <Button color="" fluid size="large" onClick={handleSubmit}> |             <Button color='green' fluid size='large' onClick={handleSubmit}> | ||||||
|               登录 |               登录 | ||||||
|             </Button> |             </Button> | ||||||
|           </Segment> |           </Segment> | ||||||
|         </Form> |         </Form> | ||||||
|         <Message> |         <Message> | ||||||
|           忘记密码? |           忘记密码? | ||||||
|           <Link to="/reset" className="btn btn-link"> |           <Link to='/reset' className='btn btn-link'> | ||||||
|             点击重置 |             点击重置 | ||||||
|           </Link> |           </Link> | ||||||
|           ; 没有账户? |           ; 没有账户? | ||||||
|           <Link to="/register" className="btn btn-link"> |           <Link to='/register' className='btn btn-link'> | ||||||
|             点击注册 |             点击注册 | ||||||
|           </Link> |           </Link> | ||||||
|         </Message> |         </Message> | ||||||
| @@ -140,9 +130,9 @@ const LoginForm = () => { | |||||||
|             {status.github_oauth ? ( |             {status.github_oauth ? ( | ||||||
|               <Button |               <Button | ||||||
|                 circular |                 circular | ||||||
|                 color="black" |                 color='black' | ||||||
|                 icon="github" |                 icon='github' | ||||||
|                 onClick={onGitHubOAuthClicked} |                 onClick={() => onGitHubOAuthClicked(status.github_client_id)} | ||||||
|               /> |               /> | ||||||
|             ) : ( |             ) : ( | ||||||
|               <></> |               <></> | ||||||
| @@ -150,8 +140,8 @@ const LoginForm = () => { | |||||||
|             {status.wechat_login ? ( |             {status.wechat_login ? ( | ||||||
|               <Button |               <Button | ||||||
|                 circular |                 circular | ||||||
|                 color="green" |                 color='green' | ||||||
|                 icon="wechat" |                 icon='wechat' | ||||||
|                 onClick={onWeChatLoginClicked} |                 onClick={onWeChatLoginClicked} | ||||||
|               /> |               /> | ||||||
|             ) : ( |             ) : ( | ||||||
| @@ -175,18 +165,18 @@ const LoginForm = () => { | |||||||
|                   微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) |                   微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) | ||||||
|                 </p> |                 </p> | ||||||
|               </div> |               </div> | ||||||
|               <Form size="large"> |               <Form size='large'> | ||||||
|                 <Form.Input |                 <Form.Input | ||||||
|                   fluid |                   fluid | ||||||
|                   placeholder="验证码" |                   placeholder='验证码' | ||||||
|                   name="wechat_verification_code" |                   name='wechat_verification_code' | ||||||
|                   value={inputs.wechat_verification_code} |                   value={inputs.wechat_verification_code} | ||||||
|                   onChange={handleChange} |                   onChange={handleChange} | ||||||
|                 /> |                 /> | ||||||
|                 <Button |                 <Button | ||||||
|                   color="" |                   color='' | ||||||
|                   fluid |                   fluid | ||||||
|                   size="large" |                   size='large' | ||||||
|                   onClick={onSubmitWeChatVerificationCode} |                   onClick={onSubmitWeChatVerificationCode} | ||||||
|                 > |                 > | ||||||
|                   登录 |                   登录 | ||||||
|   | |||||||
| @@ -43,6 +43,7 @@ function renderType(type) { | |||||||
|  |  | ||||||
| const LogsTable = () => { | const LogsTable = () => { | ||||||
|   const [logs, setLogs] = useState([]); |   const [logs, setLogs] = useState([]); | ||||||
|  |   const [showStat, setShowStat] = useState(false); | ||||||
|   const [loading, setLoading] = useState(true); |   const [loading, setLoading] = useState(true); | ||||||
|   const [activePage, setActivePage] = useState(1); |   const [activePage, setActivePage] = useState(1); | ||||||
|   const [searchKeyword, setSearchKeyword] = useState(''); |   const [searchKeyword, setSearchKeyword] = useState(''); | ||||||
| @@ -55,9 +56,10 @@ const LogsTable = () => { | |||||||
|     token_name: '', |     token_name: '', | ||||||
|     model_name: '', |     model_name: '', | ||||||
|     start_timestamp: timestamp2string(0), |     start_timestamp: timestamp2string(0), | ||||||
|     end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) |     end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), | ||||||
|  |     channel: '' | ||||||
|   }); |   }); | ||||||
|   const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; |   const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; | ||||||
|  |  | ||||||
|   const [stat, setStat] = useState({ |   const [stat, setStat] = useState({ | ||||||
|     quota: 0, |     quota: 0, | ||||||
| @@ -83,7 +85,7 @@ const LogsTable = () => { | |||||||
|   const getLogStat = async () => { |   const getLogStat = async () => { | ||||||
|     let localStartTimestamp = Date.parse(start_timestamp) / 1000; |     let localStartTimestamp = Date.parse(start_timestamp) / 1000; | ||||||
|     let localEndTimestamp = Date.parse(end_timestamp) / 1000; |     let localEndTimestamp = Date.parse(end_timestamp) / 1000; | ||||||
|     let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); |     let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); | ||||||
|     const { success, message, data } = res.data; |     const { success, message, data } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       setStat(data); |       setStat(data); | ||||||
| @@ -92,12 +94,23 @@ const LogsTable = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const handleEyeClick = async () => { | ||||||
|  |     if (!showStat) { | ||||||
|  |       if (isAdminUser) { | ||||||
|  |         await getLogStat(); | ||||||
|  |       } else { | ||||||
|  |         await getLogSelfStat(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     setShowStat(!showStat); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   const loadLogs = async (startIdx) => { |   const loadLogs = async (startIdx) => { | ||||||
|     let url = ''; |     let url = ''; | ||||||
|     let localStartTimestamp = Date.parse(start_timestamp) / 1000; |     let localStartTimestamp = Date.parse(start_timestamp) / 1000; | ||||||
|     let localEndTimestamp = Date.parse(end_timestamp) / 1000; |     let localEndTimestamp = Date.parse(end_timestamp) / 1000; | ||||||
|     if (isAdminUser) { |     if (isAdminUser) { | ||||||
|       url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; |       url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; | ||||||
|     } else { |     } else { | ||||||
|       url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; |       url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; | ||||||
|     } |     } | ||||||
| @@ -129,13 +142,8 @@ const LogsTable = () => { | |||||||
|  |  | ||||||
|   const refresh = async () => { |   const refresh = async () => { | ||||||
|     setLoading(true); |     setLoading(true); | ||||||
|     setActivePage(1) |     setActivePage(1); | ||||||
|     await loadLogs(0); |     await loadLogs(0); | ||||||
|     if (isAdminUser) { |  | ||||||
|       getLogStat().then(); |  | ||||||
|     } else { |  | ||||||
|       getLogSelfStat().then(); |  | ||||||
|     } |  | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
| @@ -169,7 +177,7 @@ const LogsTable = () => { | |||||||
|     if (logs.length === 0) return; |     if (logs.length === 0) return; | ||||||
|     setLoading(true); |     setLoading(true); | ||||||
|     let sortedLogs = [...logs]; |     let sortedLogs = [...logs]; | ||||||
|     if (typeof sortedLogs[0][key] === 'string'){ |     if (typeof sortedLogs[0][key] === 'string') { | ||||||
|       sortedLogs.sort((a, b) => { |       sortedLogs.sort((a, b) => { | ||||||
|         return ('' + a[key]).localeCompare(b[key]); |         return ('' + a[key]).localeCompare(b[key]); | ||||||
|       }); |       }); | ||||||
| @@ -190,19 +198,17 @@ const LogsTable = () => { | |||||||
|   return ( |   return ( | ||||||
|     <> |     <> | ||||||
|       <Segment> |       <Segment> | ||||||
|         <Header as='h3'>使用明细(总消耗额度:{renderQuota(stat.quota)})</Header> |         <Header as='h3'> | ||||||
|  |           使用明细(总消耗额度: | ||||||
|  |           {showStat && renderQuota(stat.quota)} | ||||||
|  |           {!showStat && <span onClick={handleEyeClick} style={{ cursor: 'pointer', color: 'gray' }}>点击查看</span>} | ||||||
|  |           ) | ||||||
|  |         </Header> | ||||||
|         <Form> |         <Form> | ||||||
|           <Form.Group> |           <Form.Group> | ||||||
|             { |             <Form.Input fluid label={'令牌名称'} width={3} value={token_name} | ||||||
|               isAdminUser && ( |  | ||||||
|                 <Form.Input fluid label={'用户名称'} width={2} value={username} |  | ||||||
|                             placeholder={'可选值'} name='username' |  | ||||||
|                             onChange={handleInputChange} /> |  | ||||||
|               ) |  | ||||||
|             } |  | ||||||
|             <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name} |  | ||||||
|                         placeholder={'可选值'} name='token_name' onChange={handleInputChange} /> |                         placeholder={'可选值'} name='token_name' onChange={handleInputChange} /> | ||||||
|             <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值' |             <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值' | ||||||
|                         name='model_name' |                         name='model_name' | ||||||
|                         onChange={handleInputChange} /> |                         onChange={handleInputChange} /> | ||||||
|             <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local' |             <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local' | ||||||
| @@ -213,6 +219,19 @@ const LogsTable = () => { | |||||||
|                         onChange={handleInputChange} /> |                         onChange={handleInputChange} /> | ||||||
|             <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button> |             <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button> | ||||||
|           </Form.Group> |           </Form.Group> | ||||||
|  |           { | ||||||
|  |             isAdminUser && <> | ||||||
|  |               <Form.Group> | ||||||
|  |                 <Form.Input fluid label={'渠道 ID'} width={3} value={channel} | ||||||
|  |                             placeholder='可选值' name='channel' | ||||||
|  |                             onChange={handleInputChange} /> | ||||||
|  |                 <Form.Input fluid label={'用户名称'} width={3} value={username} | ||||||
|  |                             placeholder={'可选值'} name='username' | ||||||
|  |                             onChange={handleInputChange} /> | ||||||
|  |  | ||||||
|  |               </Form.Group> | ||||||
|  |             </> | ||||||
|  |           } | ||||||
|         </Form> |         </Form> | ||||||
|         <Table basic compact size='small'> |         <Table basic compact size='small'> | ||||||
|           <Table.Header> |           <Table.Header> | ||||||
| @@ -226,6 +245,17 @@ const LogsTable = () => { | |||||||
|               > |               > | ||||||
|                 时间 |                 时间 | ||||||
|               </Table.HeaderCell> |               </Table.HeaderCell> | ||||||
|  |               { | ||||||
|  |                 isAdminUser && <Table.HeaderCell | ||||||
|  |                   style={{ cursor: 'pointer' }} | ||||||
|  |                   onClick={() => { | ||||||
|  |                     sortLog('channel'); | ||||||
|  |                   }} | ||||||
|  |                   width={1} | ||||||
|  |                 > | ||||||
|  |                   渠道 | ||||||
|  |                 </Table.HeaderCell> | ||||||
|  |               } | ||||||
|               { |               { | ||||||
|                 isAdminUser && <Table.HeaderCell |                 isAdminUser && <Table.HeaderCell | ||||||
|                   style={{ cursor: 'pointer' }} |                   style={{ cursor: 'pointer' }} | ||||||
| @@ -287,16 +317,16 @@ const LogsTable = () => { | |||||||
|                 onClick={() => { |                 onClick={() => { | ||||||
|                   sortLog('quota'); |                   sortLog('quota'); | ||||||
|                 }} |                 }} | ||||||
|                 width={2} |                 width={1} | ||||||
|               > |               > | ||||||
|                 消耗额度 |                 额度 | ||||||
|               </Table.HeaderCell> |               </Table.HeaderCell> | ||||||
|               <Table.HeaderCell |               <Table.HeaderCell | ||||||
|                 style={{ cursor: 'pointer' }} |                 style={{ cursor: 'pointer' }} | ||||||
|                 onClick={() => { |                 onClick={() => { | ||||||
|                   sortLog('content'); |                   sortLog('content'); | ||||||
|                 }} |                 }} | ||||||
|                 width={isAdminUser ? 4 : 5} |                 width={isAdminUser ? 4 : 6} | ||||||
|               > |               > | ||||||
|                 详情 |                 详情 | ||||||
|               </Table.HeaderCell> |               </Table.HeaderCell> | ||||||
| @@ -312,8 +342,13 @@ const LogsTable = () => { | |||||||
|               .map((log, idx) => { |               .map((log, idx) => { | ||||||
|                 if (log.deleted) return <></>; |                 if (log.deleted) return <></>; | ||||||
|                 return ( |                 return ( | ||||||
|                   <Table.Row key={log.created_at}> |                   <Table.Row key={log.id}> | ||||||
|                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> |                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> | ||||||
|  |                     { | ||||||
|  |                       isAdminUser && ( | ||||||
|  |                         <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell> | ||||||
|  |                       ) | ||||||
|  |                     } | ||||||
|                     { |                     { | ||||||
|                       isAdminUser && ( |                       isAdminUser && ( | ||||||
|                         <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell> |                         <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell> | ||||||
| @@ -333,7 +368,7 @@ const LogsTable = () => { | |||||||
|  |  | ||||||
|           <Table.Footer> |           <Table.Footer> | ||||||
|             <Table.Row> |             <Table.Row> | ||||||
|               <Table.HeaderCell colSpan={'9'}> |               <Table.HeaderCell colSpan={'10'}> | ||||||
|                 <Select |                 <Select | ||||||
|                   placeholder='选择明细分类' |                   placeholder='选择明细分类' | ||||||
|                   options={LOG_OPTIONS} |                   options={LOG_OPTIONS} | ||||||
|   | |||||||
| @@ -1,8 +1,9 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Divider, Form, Grid, Header } from 'semantic-ui-react'; | import { Divider, Form, Grid, Header } from 'semantic-ui-react'; | ||||||
| import { API, showError, verifyJSON } from '../helpers'; | import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers'; | ||||||
|  |  | ||||||
| const OperationSetting = () => { | const OperationSetting = () => { | ||||||
|  |   let now = new Date(); | ||||||
|   let [inputs, setInputs] = useState({ |   let [inputs, setInputs] = useState({ | ||||||
|     QuotaForNewUser: 0, |     QuotaForNewUser: 0, | ||||||
|     QuotaForInviter: 0, |     QuotaForInviter: 0, | ||||||
| @@ -20,10 +21,11 @@ const OperationSetting = () => { | |||||||
|     DisplayInCurrencyEnabled: '', |     DisplayInCurrencyEnabled: '', | ||||||
|     DisplayTokenStatEnabled: '', |     DisplayTokenStatEnabled: '', | ||||||
|     ApproximateTokenEnabled: '', |     ApproximateTokenEnabled: '', | ||||||
|     RetryTimes: 0, |     RetryTimes: 0 | ||||||
|   }); |   }); | ||||||
|   const [originInputs, setOriginInputs] = useState({}); |   const [originInputs, setOriginInputs] = useState({}); | ||||||
|   let [loading, setLoading] = useState(false); |   let [loading, setLoading] = useState(false); | ||||||
|  |   let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago | ||||||
|  |  | ||||||
|   const getOptions = async () => { |   const getOptions = async () => { | ||||||
|     const res = await API.get('/api/option/'); |     const res = await API.get('/api/option/'); | ||||||
| @@ -130,6 +132,17 @@ const OperationSetting = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const deleteHistoryLogs = async () => { | ||||||
|  |     console.log(inputs); | ||||||
|  |     const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`); | ||||||
|  |     const { success, message, data } = res.data; | ||||||
|  |     if (success) { | ||||||
|  |       showSuccess(`${data} 条日志已清理!`); | ||||||
|  |       return; | ||||||
|  |     } | ||||||
|  |     showError('日志清理失败:' + message); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Grid columns={1}> |     <Grid columns={1}> | ||||||
|       <Grid.Column> |       <Grid.Column> | ||||||
| @@ -179,12 +192,6 @@ const OperationSetting = () => { | |||||||
|             /> |             /> | ||||||
|           </Form.Group> |           </Form.Group> | ||||||
|           <Form.Group inline> |           <Form.Group inline> | ||||||
|             <Form.Checkbox |  | ||||||
|               checked={inputs.LogConsumeEnabled === 'true'} |  | ||||||
|               label='启用额度消费日志记录' |  | ||||||
|               name='LogConsumeEnabled' |  | ||||||
|               onChange={handleInputChange} |  | ||||||
|             /> |  | ||||||
|             <Form.Checkbox |             <Form.Checkbox | ||||||
|               checked={inputs.DisplayInCurrencyEnabled === 'true'} |               checked={inputs.DisplayInCurrencyEnabled === 'true'} | ||||||
|               label='以货币形式显示额度' |               label='以货币形式显示额度' | ||||||
| @@ -208,6 +215,28 @@ const OperationSetting = () => { | |||||||
|             submitConfig('general').then(); |             submitConfig('general').then(); | ||||||
|           }}>保存通用设置</Form.Button> |           }}>保存通用设置</Form.Button> | ||||||
|           <Divider /> |           <Divider /> | ||||||
|  |           <Header as='h3'> | ||||||
|  |             日志设置 | ||||||
|  |           </Header> | ||||||
|  |           <Form.Group inline> | ||||||
|  |             <Form.Checkbox | ||||||
|  |               checked={inputs.LogConsumeEnabled === 'true'} | ||||||
|  |               label='启用额度消费日志记录' | ||||||
|  |               name='LogConsumeEnabled' | ||||||
|  |               onChange={handleInputChange} | ||||||
|  |             /> | ||||||
|  |           </Form.Group> | ||||||
|  |           <Form.Group widths={4}> | ||||||
|  |             <Form.Input label='目标时间' value={historyTimestamp} type='datetime-local' | ||||||
|  |                         name='history_timestamp' | ||||||
|  |                         onChange={(e, { name, value }) => { | ||||||
|  |                           setHistoryTimestamp(value); | ||||||
|  |                         }} /> | ||||||
|  |           </Form.Group> | ||||||
|  |           <Form.Button onClick={() => { | ||||||
|  |             deleteHistoryLogs().then(); | ||||||
|  |           }}>清理历史日志</Form.Button> | ||||||
|  |           <Divider /> | ||||||
|           <Header as='h3'> |           <Header as='h3'> | ||||||
|             监控设置 |             监控设置 | ||||||
|           </Header> |           </Header> | ||||||
|   | |||||||
| @@ -112,7 +112,7 @@ const OtherSetting = () => { | |||||||
|           <Form.Group widths='equal'> |           <Form.Group widths='equal'> | ||||||
|             <Form.TextArea |             <Form.TextArea | ||||||
|               label='公告' |               label='公告' | ||||||
|               placeholder='在此输入新的公告内容' |               placeholder='在此输入新的公告内容,支持 Markdown & HTML 代码' | ||||||
|               value={inputs.Notice} |               value={inputs.Notice} | ||||||
|               name='Notice' |               name='Notice' | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|   | |||||||
| @@ -12,6 +12,11 @@ const PasswordResetConfirm = () => { | |||||||
|  |  | ||||||
|   const [loading, setLoading] = useState(false); |   const [loading, setLoading] = useState(false); | ||||||
|  |  | ||||||
|  |   const [disableButton, setDisableButton] = useState(false); | ||||||
|  |   const [countdown, setCountdown] = useState(30); | ||||||
|  |  | ||||||
|  |   const [newPassword, setNewPassword] = useState(''); | ||||||
|  |  | ||||||
|   const [searchParams, setSearchParams] = useSearchParams(); |   const [searchParams, setSearchParams] = useSearchParams(); | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let token = searchParams.get('token'); |     let token = searchParams.get('token'); | ||||||
| @@ -22,7 +27,21 @@ const PasswordResetConfirm = () => { | |||||||
|     }); |     }); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|  |   useEffect(() => { | ||||||
|  |     let countdownInterval = null; | ||||||
|  |     if (disableButton && countdown > 0) { | ||||||
|  |       countdownInterval = setInterval(() => { | ||||||
|  |         setCountdown(countdown - 1); | ||||||
|  |       }, 1000); | ||||||
|  |     } else if (countdown === 0) { | ||||||
|  |       setDisableButton(false); | ||||||
|  |       setCountdown(30); | ||||||
|  |     } | ||||||
|  |     return () => clearInterval(countdownInterval);  | ||||||
|  |   }, [disableButton, countdown]); | ||||||
|  |  | ||||||
|   async function handleSubmit(e) { |   async function handleSubmit(e) { | ||||||
|  |     setDisableButton(true); | ||||||
|     if (!email) return; |     if (!email) return; | ||||||
|     setLoading(true); |     setLoading(true); | ||||||
|     const res = await API.post(`/api/user/reset`, { |     const res = await API.post(`/api/user/reset`, { | ||||||
| @@ -32,8 +51,9 @@ const PasswordResetConfirm = () => { | |||||||
|     const { success, message } = res.data; |     const { success, message } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       let password = res.data.data; |       let password = res.data.data; | ||||||
|  |       setNewPassword(password); | ||||||
|       await copy(password); |       await copy(password); | ||||||
|       showNotice(`密码已重置并已复制到剪贴板:${password}`); |       showNotice(`新密码已复制到剪贴板:${password}`); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|     } |     } | ||||||
| @@ -57,14 +77,31 @@ const PasswordResetConfirm = () => { | |||||||
|               value={email} |               value={email} | ||||||
|               readOnly |               readOnly | ||||||
|             /> |             /> | ||||||
|  |             {newPassword && ( | ||||||
|  |               <Form.Input | ||||||
|  |               fluid | ||||||
|  |               icon='lock' | ||||||
|  |               iconPosition='left' | ||||||
|  |               placeholder='新密码' | ||||||
|  |               name='newPassword' | ||||||
|  |               value={newPassword} | ||||||
|  |               readOnly | ||||||
|  |               onClick={(e) => { | ||||||
|  |                 e.target.select(); | ||||||
|  |                 navigator.clipboard.writeText(newPassword); | ||||||
|  |                 showNotice(`密码已复制到剪贴板:${newPassword}`); | ||||||
|  |               }} | ||||||
|  |             />             | ||||||
|  |             )} | ||||||
|             <Button |             <Button | ||||||
|               color='' |               color='green' | ||||||
|               fluid |               fluid | ||||||
|               size='large' |               size='large' | ||||||
|               onClick={handleSubmit} |               onClick={handleSubmit} | ||||||
|               loading={loading} |               loading={loading} | ||||||
|  |               disabled={disableButton} | ||||||
|             > |             > | ||||||
|               提交 |               {disableButton ? `密码重置完成` : '提交'} | ||||||
|             </Button> |             </Button> | ||||||
|           </Segment> |           </Segment> | ||||||
|         </Form> |         </Form> | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ import Turnstile from 'react-turnstile'; | |||||||
|  |  | ||||||
| const PasswordResetForm = () => { | const PasswordResetForm = () => { | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
|     email: '', |     email: '' | ||||||
|   }); |   }); | ||||||
|   const { email } = inputs; |   const { email } = inputs; | ||||||
|  |  | ||||||
| @@ -13,24 +13,29 @@ const PasswordResetForm = () => { | |||||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); |   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); |   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||||
|   const [turnstileToken, setTurnstileToken] = useState(''); |   const [turnstileToken, setTurnstileToken] = useState(''); | ||||||
|  |   const [disableButton, setDisableButton] = useState(false); | ||||||
|  |   const [countdown, setCountdown] = useState(30); | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let status = localStorage.getItem('status'); |     let countdownInterval = null; | ||||||
|     if (status) { |     if (disableButton && countdown > 0) { | ||||||
|       status = JSON.parse(status); |       countdownInterval = setInterval(() => { | ||||||
|       if (status.turnstile_check) { |         setCountdown(countdown - 1); | ||||||
|         setTurnstileEnabled(true); |       }, 1000); | ||||||
|         setTurnstileSiteKey(status.turnstile_site_key); |     } else if (countdown === 0) { | ||||||
|  |       setDisableButton(false); | ||||||
|  |       setCountdown(30); | ||||||
|     } |     } | ||||||
|     } |     return () => clearInterval(countdownInterval); | ||||||
|   }, []); |   }, [disableButton, countdown]); | ||||||
|  |  | ||||||
|   function handleChange(e) { |   function handleChange(e) { | ||||||
|     const { name, value } = e.target; |     const { name, value } = e.target; | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs(inputs => ({ ...inputs, [name]: value })); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   async function handleSubmit(e) { |   async function handleSubmit(e) { | ||||||
|  |     setDisableButton(true); | ||||||
|     if (!email) return; |     if (!email) return; | ||||||
|     if (turnstileEnabled && turnstileToken === '') { |     if (turnstileEnabled && turnstileToken === '') { | ||||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); |       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||||
| @@ -78,13 +83,14 @@ const PasswordResetForm = () => { | |||||||
|               <></> |               <></> | ||||||
|             )} |             )} | ||||||
|             <Button |             <Button | ||||||
|               color='' |               color='green' | ||||||
|               fluid |               fluid | ||||||
|               size='large' |               size='large' | ||||||
|               onClick={handleSubmit} |               onClick={handleSubmit} | ||||||
|               loading={loading} |               loading={loading} | ||||||
|  |               disabled={disableButton} | ||||||
|             > |             > | ||||||
|               提交 |               {disableButton ? `重试 (${countdown})` : '提交'} | ||||||
|             </Button> |             </Button> | ||||||
|           </Segment> |           </Segment> | ||||||
|         </Form> |         </Form> | ||||||
|   | |||||||
| @@ -1,22 +1,33 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useContext, useEffect, useState } from 'react'; | ||||||
| import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; | import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link, useNavigate } from 'react-router-dom'; | ||||||
| import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | ||||||
| import Turnstile from 'react-turnstile'; | import Turnstile from 'react-turnstile'; | ||||||
|  | import { UserContext } from '../context/User'; | ||||||
|  | import { onGitHubOAuthClicked } from './utils'; | ||||||
|  |  | ||||||
| const PersonalSetting = () => { | const PersonalSetting = () => { | ||||||
|  |   const [userState, userDispatch] = useContext(UserContext); | ||||||
|  |   let navigate = useNavigate(); | ||||||
|  |  | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
|     wechat_verification_code: '', |     wechat_verification_code: '', | ||||||
|     email_verification_code: '', |     email_verification_code: '', | ||||||
|     email: '', |     email: '', | ||||||
|  |     self_account_deletion_confirmation: '' | ||||||
|   }); |   }); | ||||||
|   const [status, setStatus] = useState({}); |   const [status, setStatus] = useState({}); | ||||||
|   const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); |   const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); | ||||||
|   const [showEmailBindModal, setShowEmailBindModal] = useState(false); |   const [showEmailBindModal, setShowEmailBindModal] = useState(false); | ||||||
|  |   const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); | ||||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); |   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); |   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||||
|   const [turnstileToken, setTurnstileToken] = useState(''); |   const [turnstileToken, setTurnstileToken] = useState(''); | ||||||
|   const [loading, setLoading] = useState(false); |   const [loading, setLoading] = useState(false); | ||||||
|  |   const [disableButton, setDisableButton] = useState(false); | ||||||
|  |   const [countdown, setCountdown] = useState(30); | ||||||
|  |   const [affLink, setAffLink] = useState(""); | ||||||
|  |   const [systemToken, setSystemToken] = useState(""); | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let status = localStorage.getItem('status'); |     let status = localStorage.getItem('status'); | ||||||
| @@ -30,6 +41,19 @@ const PersonalSetting = () => { | |||||||
|     } |     } | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|  |   useEffect(() => { | ||||||
|  |     let countdownInterval = null; | ||||||
|  |     if (disableButton && countdown > 0) { | ||||||
|  |       countdownInterval = setInterval(() => { | ||||||
|  |         setCountdown(countdown - 1); | ||||||
|  |       }, 1000); | ||||||
|  |     } else if (countdown === 0) { | ||||||
|  |       setDisableButton(false); | ||||||
|  |       setCountdown(30); | ||||||
|  |     } | ||||||
|  |     return () => clearInterval(countdownInterval); // Clean up on unmount | ||||||
|  |   }, [disableButton, countdown]); | ||||||
|  |  | ||||||
|   const handleInputChange = (e, { name, value }) => { |   const handleInputChange = (e, { name, value }) => { | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|   }; |   }; | ||||||
| @@ -38,8 +62,10 @@ const PersonalSetting = () => { | |||||||
|     const res = await API.get('/api/user/token'); |     const res = await API.get('/api/user/token'); | ||||||
|     const { success, message, data } = res.data; |     const { success, message, data } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|  |       setSystemToken(data); | ||||||
|  |       setAffLink("");  | ||||||
|       await copy(data); |       await copy(data); | ||||||
|       showSuccess(`令牌已重置并已复制到剪贴板:${data}`); |       showSuccess(`令牌已重置并已复制到剪贴板`); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|     } |     } | ||||||
| @@ -50,8 +76,42 @@ const PersonalSetting = () => { | |||||||
|     const { success, message, data } = res.data; |     const { success, message, data } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       let link = `${window.location.origin}/register?aff=${data}`; |       let link = `${window.location.origin}/register?aff=${data}`; | ||||||
|  |       setAffLink(link); | ||||||
|  |       setSystemToken(""); | ||||||
|       await copy(link); |       await copy(link); | ||||||
|       showNotice(`邀请链接已复制到剪切板:${link}`); |       showSuccess(`邀请链接已复制到剪切板`); | ||||||
|  |     } else { | ||||||
|  |       showError(message); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   const handleAffLinkClick = async (e) => { | ||||||
|  |     e.target.select(); | ||||||
|  |     await copy(e.target.value); | ||||||
|  |     showSuccess(`邀请链接已复制到剪切板`); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   const handleSystemTokenClick = async (e) => { | ||||||
|  |     e.target.select(); | ||||||
|  |     await copy(e.target.value); | ||||||
|  |     showSuccess(`系统令牌已复制到剪切板`); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   const deleteAccount = async () => { | ||||||
|  |     if (inputs.self_account_deletion_confirmation !== userState.user.username) { | ||||||
|  |       showError('请输入你的账户名以确认删除!'); | ||||||
|  |       return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const res = await API.delete('/api/user/self'); | ||||||
|  |     const { success, message } = res.data; | ||||||
|  |  | ||||||
|  |     if (success) { | ||||||
|  |       showSuccess('账户已删除!'); | ||||||
|  |       await API.get('/api/user/logout'); | ||||||
|  |       userDispatch({ type: 'logout' }); | ||||||
|  |       localStorage.removeItem('user'); | ||||||
|  |       navigate('/login'); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|     } |     } | ||||||
| @@ -71,13 +131,8 @@ const PersonalSetting = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const openGitHubOAuth = () => { |  | ||||||
|     window.open( |  | ||||||
|       `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` |  | ||||||
|     ); |  | ||||||
|   }; |  | ||||||
|  |  | ||||||
|   const sendVerificationCode = async () => { |   const sendVerificationCode = async () => { | ||||||
|  |     setDisableButton(true); | ||||||
|     if (inputs.email === '') return; |     if (inputs.email === '') return; | ||||||
|     if (turnstileEnabled && turnstileToken === '') { |     if (turnstileEnabled && turnstileToken === '') { | ||||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); |       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||||
| @@ -123,6 +178,28 @@ const PersonalSetting = () => { | |||||||
|       </Button> |       </Button> | ||||||
|       <Button onClick={generateAccessToken}>生成系统访问令牌</Button> |       <Button onClick={generateAccessToken}>生成系统访问令牌</Button> | ||||||
|       <Button onClick={getAffLink}>复制邀请链接</Button> |       <Button onClick={getAffLink}>复制邀请链接</Button> | ||||||
|  |       <Button onClick={() => { | ||||||
|  |         setShowAccountDeleteModal(true); | ||||||
|  |       }}>删除个人账户</Button> | ||||||
|  |        | ||||||
|  |       {systemToken && ( | ||||||
|  |         <Form.Input  | ||||||
|  |           fluid  | ||||||
|  |           readOnly  | ||||||
|  |           value={systemToken}  | ||||||
|  |           onClick={handleSystemTokenClick} | ||||||
|  |           style={{ marginTop: '10px' }} | ||||||
|  |         /> | ||||||
|  |       )} | ||||||
|  |       {affLink && ( | ||||||
|  |         <Form.Input  | ||||||
|  |           fluid  | ||||||
|  |           readOnly  | ||||||
|  |           value={affLink}  | ||||||
|  |           onClick={handleAffLinkClick} | ||||||
|  |           style={{ marginTop: '10px' }} | ||||||
|  |         /> | ||||||
|  |       )} | ||||||
|       <Divider /> |       <Divider /> | ||||||
|       <Header as='h3'>账号绑定</Header> |       <Header as='h3'>账号绑定</Header> | ||||||
|       { |       { | ||||||
| @@ -167,7 +244,7 @@ const PersonalSetting = () => { | |||||||
|       </Modal> |       </Modal> | ||||||
|       { |       { | ||||||
|         status.github_oauth && ( |         status.github_oauth && ( | ||||||
|           <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button> |           <Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button> | ||||||
|         ) |         ) | ||||||
|       } |       } | ||||||
|       <Button |       <Button | ||||||
| @@ -195,8 +272,8 @@ const PersonalSetting = () => { | |||||||
|                 name='email' |                 name='email' | ||||||
|                 type='email' |                 type='email' | ||||||
|                 action={ |                 action={ | ||||||
|                   <Button onClick={sendVerificationCode} disabled={loading}> |                   <Button onClick={sendVerificationCode} disabled={disableButton || loading}> | ||||||
|                     获取验证码 |                     {disableButton ? `重新发送(${countdown})` : '获取验证码'} | ||||||
|                   </Button> |                   </Button> | ||||||
|                 } |                 } | ||||||
|               /> |               /> | ||||||
| @@ -217,6 +294,7 @@ const PersonalSetting = () => { | |||||||
|               ) : ( |               ) : ( | ||||||
|                 <></> |                 <></> | ||||||
|               )} |               )} | ||||||
|  |               <div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}> | ||||||
|               <Button |               <Button | ||||||
|                 color='' |                 color='' | ||||||
|                 fluid |                 fluid | ||||||
| @@ -224,8 +302,69 @@ const PersonalSetting = () => { | |||||||
|                 onClick={bindEmail} |                 onClick={bindEmail} | ||||||
|                 loading={loading} |                 loading={loading} | ||||||
|               > |               > | ||||||
|                 绑定 |                 确认绑定 | ||||||
|               </Button> |               </Button> | ||||||
|  |               <div style={{ width: '1rem' }}></div>  | ||||||
|  |               <Button | ||||||
|  |                 fluid | ||||||
|  |                 size='large' | ||||||
|  |                 onClick={() => setShowEmailBindModal(false)} | ||||||
|  |               > | ||||||
|  |                 取消 | ||||||
|  |               </Button> | ||||||
|  |               </div> | ||||||
|  |             </Form> | ||||||
|  |           </Modal.Description> | ||||||
|  |         </Modal.Content> | ||||||
|  |       </Modal> | ||||||
|  |       <Modal | ||||||
|  |         onClose={() => setShowAccountDeleteModal(false)} | ||||||
|  |         onOpen={() => setShowAccountDeleteModal(true)} | ||||||
|  |         open={showAccountDeleteModal} | ||||||
|  |         size={'tiny'} | ||||||
|  |         style={{ maxWidth: '450px' }} | ||||||
|  |       > | ||||||
|  |         <Modal.Header>危险操作</Modal.Header> | ||||||
|  |         <Modal.Content> | ||||||
|  |         <Message>您正在删除自己的帐户,将清空所有数据且不可恢复</Message> | ||||||
|  |           <Modal.Description> | ||||||
|  |             <Form size='large'> | ||||||
|  |               <Form.Input | ||||||
|  |                 fluid | ||||||
|  |                 placeholder={`输入你的账户名 ${userState?.user?.username} 以确认删除`} | ||||||
|  |                 name='self_account_deletion_confirmation' | ||||||
|  |                 value={inputs.self_account_deletion_confirmation} | ||||||
|  |                 onChange={handleInputChange} | ||||||
|  |               /> | ||||||
|  |               {turnstileEnabled ? ( | ||||||
|  |                 <Turnstile | ||||||
|  |                   sitekey={turnstileSiteKey} | ||||||
|  |                   onVerify={(token) => { | ||||||
|  |                     setTurnstileToken(token); | ||||||
|  |                   }} | ||||||
|  |                 /> | ||||||
|  |               ) : ( | ||||||
|  |                 <></> | ||||||
|  |               )} | ||||||
|  |               <div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}> | ||||||
|  |                 <Button | ||||||
|  |                   color='red' | ||||||
|  |                   fluid | ||||||
|  |                   size='large' | ||||||
|  |                   onClick={deleteAccount} | ||||||
|  |                   loading={loading} | ||||||
|  |                 > | ||||||
|  |                   确认删除 | ||||||
|  |                 </Button> | ||||||
|  |                 <div style={{ width: '1rem' }}></div> | ||||||
|  |                 <Button | ||||||
|  |                   fluid | ||||||
|  |                   size='large' | ||||||
|  |                   onClick={() => setShowAccountDeleteModal(false)} | ||||||
|  |                 > | ||||||
|  |                   取消 | ||||||
|  |                 </Button> | ||||||
|  |               </div> | ||||||
|             </Form> |             </Form> | ||||||
|           </Modal.Description> |           </Modal.Description> | ||||||
|         </Modal.Content> |         </Modal.Content> | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Label, Message, Pagination, Table } from 'semantic-ui-react'; | import { Button, Form, Label, Popup, Pagination, Table } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link } from 'react-router-dom'; | ||||||
| import { API, copy, showError, showInfo, showSuccess, showWarning, timestamp2string } from '../helpers'; | import { API, copy, showError, showInfo, showSuccess, showWarning, timestamp2string } from '../helpers'; | ||||||
|  |  | ||||||
| @@ -240,15 +240,25 @@ const RedemptionsTable = () => { | |||||||
|                       > |                       > | ||||||
|                         复制 |                         复制 | ||||||
|                       </Button> |                       </Button> | ||||||
|  |                       <Popup | ||||||
|  |                         trigger={ | ||||||
|  |                           <Button size='small' negative> | ||||||
|  |                             删除 | ||||||
|  |                           </Button> | ||||||
|  |                         } | ||||||
|  |                         on='click' | ||||||
|  |                         flowing | ||||||
|  |                         hoverable | ||||||
|  |                       > | ||||||
|                         <Button |                         <Button | ||||||
|                         size={'small'} |  | ||||||
|                           negative |                           negative | ||||||
|                           onClick={() => { |                           onClick={() => { | ||||||
|                             manageRedemption(redemption.id, 'delete', idx); |                             manageRedemption(redemption.id, 'delete', idx); | ||||||
|                           }} |                           }} | ||||||
|                         > |                         > | ||||||
|                         删除 |                           确认删除 | ||||||
|                         </Button> |                         </Button> | ||||||
|  |                       </Popup> | ||||||
|                       <Button |                       <Button | ||||||
|                         size={'small'} |                         size={'small'} | ||||||
|                         disabled={redemption.status === 3}  // used |                         disabled={redemption.status === 3}  // used | ||||||
|   | |||||||
| @@ -1,13 +1,5 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { | import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; | ||||||
|   Button, |  | ||||||
|   Form, |  | ||||||
|   Grid, |  | ||||||
|   Header, |  | ||||||
|   Image, |  | ||||||
|   Message, |  | ||||||
|   Segment, |  | ||||||
| } from 'semantic-ui-react'; |  | ||||||
| import { Link, useNavigate } from 'react-router-dom'; | import { Link, useNavigate } from 'react-router-dom'; | ||||||
| import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; | import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; | ||||||
| import Turnstile from 'react-turnstile'; | import Turnstile from 'react-turnstile'; | ||||||
| @@ -18,7 +10,7 @@ const RegisterForm = () => { | |||||||
|     password: '', |     password: '', | ||||||
|     password2: '', |     password2: '', | ||||||
|     email: '', |     email: '', | ||||||
|     verification_code: '', |     verification_code: '' | ||||||
|   }); |   }); | ||||||
|   const { username, password, password2 } = inputs; |   const { username, password, password2 } = inputs; | ||||||
|   const [showEmailVerification, setShowEmailVerification] = useState(false); |   const [showEmailVerification, setShowEmailVerification] = useState(false); | ||||||
| @@ -178,7 +170,7 @@ const RegisterForm = () => { | |||||||
|               <></> |               <></> | ||||||
|             )} |             )} | ||||||
|             <Button |             <Button | ||||||
|               color='' |               color='green' | ||||||
|               fluid |               fluid | ||||||
|               size='large' |               size='large' | ||||||
|               onClick={handleSubmit} |               onClick={handleSubmit} | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react'; | import { Button, Divider, Form, Grid, Header, Modal, Message } from 'semantic-ui-react'; | ||||||
| import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers'; | import { API, removeTrailingSlash, showError } from '../helpers'; | ||||||
|  |  | ||||||
| const SystemSetting = () => { | const SystemSetting = () => { | ||||||
|   let [inputs, setInputs] = useState({ |   let [inputs, setInputs] = useState({ | ||||||
| @@ -26,9 +26,14 @@ const SystemSetting = () => { | |||||||
|     TurnstileSiteKey: '', |     TurnstileSiteKey: '', | ||||||
|     TurnstileSecretKey: '', |     TurnstileSecretKey: '', | ||||||
|     RegisterEnabled: '', |     RegisterEnabled: '', | ||||||
|  |     EmailDomainRestrictionEnabled: '', | ||||||
|  |     EmailDomainWhitelist: '' | ||||||
|   }); |   }); | ||||||
|   const [originInputs, setOriginInputs] = useState({}); |   const [originInputs, setOriginInputs] = useState({}); | ||||||
|   let [loading, setLoading] = useState(false); |   let [loading, setLoading] = useState(false); | ||||||
|  |   const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); | ||||||
|  |   const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); | ||||||
|  |   const [showPasswordWarningModal, setShowPasswordWarningModal] = useState(false); | ||||||
|  |  | ||||||
|   const getOptions = async () => { |   const getOptions = async () => { | ||||||
|     const res = await API.get('/api/option/'); |     const res = await API.get('/api/option/'); | ||||||
| @@ -38,8 +43,15 @@ const SystemSetting = () => { | |||||||
|       data.forEach((item) => { |       data.forEach((item) => { | ||||||
|         newInputs[item.key] = item.value; |         newInputs[item.key] = item.value; | ||||||
|       }); |       }); | ||||||
|       setInputs(newInputs); |       setInputs({ | ||||||
|  |         ...newInputs, | ||||||
|  |         EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',') | ||||||
|  |       }); | ||||||
|       setOriginInputs(newInputs); |       setOriginInputs(newInputs); | ||||||
|  |  | ||||||
|  |       setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => { | ||||||
|  |         return { key: item, text: item, value: item }; | ||||||
|  |       })); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|     } |     } | ||||||
| @@ -58,6 +70,7 @@ const SystemSetting = () => { | |||||||
|       case 'GitHubOAuthEnabled': |       case 'GitHubOAuthEnabled': | ||||||
|       case 'WeChatAuthEnabled': |       case 'WeChatAuthEnabled': | ||||||
|       case 'TurnstileCheckEnabled': |       case 'TurnstileCheckEnabled': | ||||||
|  |       case 'EmailDomainRestrictionEnabled': | ||||||
|       case 'RegisterEnabled': |       case 'RegisterEnabled': | ||||||
|         value = inputs[key] === 'true' ? 'false' : 'true'; |         value = inputs[key] === 'true' ? 'false' : 'true'; | ||||||
|         break; |         break; | ||||||
| @@ -70,7 +83,12 @@ const SystemSetting = () => { | |||||||
|     }); |     }); | ||||||
|     const { success, message } = res.data; |     const { success, message } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       setInputs((inputs) => ({ ...inputs, [key]: value })); |       if (key === 'EmailDomainWhitelist') { | ||||||
|  |         value = value.split(','); | ||||||
|  |       } | ||||||
|  |       setInputs((inputs) => ({ | ||||||
|  |         ...inputs, [key]: value | ||||||
|  |       })); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|     } |     } | ||||||
| @@ -78,6 +96,11 @@ const SystemSetting = () => { | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const handleInputChange = async (e, { name, value }) => { |   const handleInputChange = async (e, { name, value }) => { | ||||||
|  |     if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') { | ||||||
|  |       // block disabling password login | ||||||
|  |       setShowPasswordWarningModal(true); | ||||||
|  |       return; | ||||||
|  |     } | ||||||
|     if ( |     if ( | ||||||
|       name === 'Notice' || |       name === 'Notice' || | ||||||
|       name.startsWith('SMTP') || |       name.startsWith('SMTP') || | ||||||
| @@ -88,7 +111,8 @@ const SystemSetting = () => { | |||||||
|       name === 'WeChatServerToken' || |       name === 'WeChatServerToken' || | ||||||
|       name === 'WeChatAccountQRCodeImageURL' || |       name === 'WeChatAccountQRCodeImageURL' || | ||||||
|       name === 'TurnstileSiteKey' || |       name === 'TurnstileSiteKey' || | ||||||
|       name === 'TurnstileSecretKey' |       name === 'TurnstileSecretKey' || | ||||||
|  |       name === 'EmailDomainWhitelist' | ||||||
|     ) { |     ) { | ||||||
|       setInputs((inputs) => ({ ...inputs, [name]: value })); |       setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|     } else { |     } else { | ||||||
| @@ -125,6 +149,16 @@ const SystemSetting = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   const submitEmailDomainWhitelist = async () => { | ||||||
|  |     if ( | ||||||
|  |       originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') && | ||||||
|  |       inputs.SMTPToken !== '' | ||||||
|  |     ) { | ||||||
|  |       await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(',')); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   const submitWeChat = async () => { |   const submitWeChat = async () => { | ||||||
|     if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { |     if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { | ||||||
|       await updateOption( |       await updateOption( | ||||||
| @@ -173,6 +207,22 @@ const SystemSetting = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const submitNewRestrictedDomain = () => { | ||||||
|  |     const localDomainList = inputs.EmailDomainWhitelist; | ||||||
|  |     if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) { | ||||||
|  |       setRestrictedDomainInput(''); | ||||||
|  |       setInputs({ | ||||||
|  |         ...inputs, | ||||||
|  |         EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], | ||||||
|  |       }); | ||||||
|  |       setEmailDomainWhitelist([...EmailDomainWhitelist, { | ||||||
|  |         key: restrictedDomainInput, | ||||||
|  |         text: restrictedDomainInput, | ||||||
|  |         value: restrictedDomainInput, | ||||||
|  |       }]); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Grid columns={1}> |     <Grid columns={1}> | ||||||
|       <Grid.Column> |       <Grid.Column> | ||||||
| @@ -199,6 +249,32 @@ const SystemSetting = () => { | |||||||
|               name='PasswordLoginEnabled' |               name='PasswordLoginEnabled' | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|             /> |             /> | ||||||
|  |             { | ||||||
|  |               showPasswordWarningModal && | ||||||
|  |               <Modal | ||||||
|  |                 open={showPasswordWarningModal} | ||||||
|  |                 onClose={() => setShowPasswordWarningModal(false)} | ||||||
|  |                 size={'tiny'} | ||||||
|  |                 style={{ maxWidth: '450px' }} | ||||||
|  |               > | ||||||
|  |                 <Modal.Header>警告</Modal.Header> | ||||||
|  |                 <Modal.Content> | ||||||
|  |                   <p>取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?</p> | ||||||
|  |                 </Modal.Content> | ||||||
|  |                 <Modal.Actions> | ||||||
|  |                   <Button onClick={() => setShowPasswordWarningModal(false)}>取消</Button> | ||||||
|  |                   <Button | ||||||
|  |                     color='yellow' | ||||||
|  |                     onClick={async () => { | ||||||
|  |                       setShowPasswordWarningModal(false); | ||||||
|  |                       await updateOption('PasswordLoginEnabled', 'false'); | ||||||
|  |                     }} | ||||||
|  |                   > | ||||||
|  |                     确定 | ||||||
|  |                   </Button> | ||||||
|  |                 </Modal.Actions> | ||||||
|  |               </Modal> | ||||||
|  |             } | ||||||
|             <Form.Checkbox |             <Form.Checkbox | ||||||
|               checked={inputs.PasswordRegisterEnabled === 'true'} |               checked={inputs.PasswordRegisterEnabled === 'true'} | ||||||
|               label='允许通过密码进行注册' |               label='允许通过密码进行注册' | ||||||
| @@ -239,6 +315,54 @@ const SystemSetting = () => { | |||||||
|             /> |             /> | ||||||
|           </Form.Group> |           </Form.Group> | ||||||
|           <Divider /> |           <Divider /> | ||||||
|  |           <Header as='h3'> | ||||||
|  |             配置邮箱域名白名单 | ||||||
|  |             <Header.Subheader>用以防止恶意用户利用临时邮箱批量注册</Header.Subheader> | ||||||
|  |           </Header> | ||||||
|  |           <Form.Group widths={3}> | ||||||
|  |             <Form.Checkbox | ||||||
|  |               label='启用邮箱域名白名单' | ||||||
|  |               name='EmailDomainRestrictionEnabled' | ||||||
|  |               onChange={handleInputChange} | ||||||
|  |               checked={inputs.EmailDomainRestrictionEnabled === 'true'} | ||||||
|  |             /> | ||||||
|  |           </Form.Group> | ||||||
|  |           <Form.Group widths={2}> | ||||||
|  |             <Form.Dropdown | ||||||
|  |               label='允许的邮箱域名' | ||||||
|  |               placeholder='允许的邮箱域名' | ||||||
|  |               name='EmailDomainWhitelist' | ||||||
|  |               required | ||||||
|  |               fluid | ||||||
|  |               multiple | ||||||
|  |               selection | ||||||
|  |               onChange={handleInputChange} | ||||||
|  |               value={inputs.EmailDomainWhitelist} | ||||||
|  |               autoComplete='new-password' | ||||||
|  |               options={EmailDomainWhitelist} | ||||||
|  |             /> | ||||||
|  |             <Form.Input | ||||||
|  |               label='添加新的允许的邮箱域名' | ||||||
|  |               action={ | ||||||
|  |                 <Button type='button' onClick={() => { | ||||||
|  |                   submitNewRestrictedDomain(); | ||||||
|  |                 }}>填入</Button> | ||||||
|  |               } | ||||||
|  |               onKeyDown={(e) => { | ||||||
|  |                 if (e.key === 'Enter') { | ||||||
|  |                   submitNewRestrictedDomain(); | ||||||
|  |                 } | ||||||
|  |               }} | ||||||
|  |               autoComplete='new-password' | ||||||
|  |               placeholder='输入新的允许的邮箱域名' | ||||||
|  |               value={restrictedDomainInput} | ||||||
|  |               onChange={(e, { value }) => { | ||||||
|  |                 setRestrictedDomainInput(value); | ||||||
|  |               }} | ||||||
|  |             /> | ||||||
|  |           </Form.Group> | ||||||
|  |           <Form.Button onClick={submitEmailDomainWhitelist}>保存邮箱域名白名单设置</Form.Button> | ||||||
|  |           <Divider /> | ||||||
|           <Header as='h3'> |           <Header as='h3'> | ||||||
|             配置 SMTP |             配置 SMTP | ||||||
|             <Header.Subheader>用以支持系统的邮件发送</Header.Subheader> |             <Header.Subheader>用以支持系统的邮件发送</Header.Subheader> | ||||||
| @@ -284,7 +408,7 @@ const SystemSetting = () => { | |||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               type='password' |               type='password' | ||||||
|               autoComplete='new-password' |               autoComplete='new-password' | ||||||
|               value={inputs.SMTPToken} |               checked={inputs.RegisterEnabled === 'true'} | ||||||
|               placeholder='敏感信息不会发送到前端显示' |               placeholder='敏感信息不会发送到前端显示' | ||||||
|             /> |             /> | ||||||
|           </Form.Group> |           </Form.Group> | ||||||
|   | |||||||
| @@ -1,11 +1,22 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Label, Modal, Pagination, Popup, Table } from 'semantic-ui-react'; | import { Button, Dropdown, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link } from 'react-router-dom'; | ||||||
| import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers'; | import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers'; | ||||||
|  |  | ||||||
| import { ITEMS_PER_PAGE } from '../constants'; | import { ITEMS_PER_PAGE } from '../constants'; | ||||||
| import { renderQuota } from '../helpers/render'; | import { renderQuota } from '../helpers/render'; | ||||||
|  |  | ||||||
|  | const COPY_OPTIONS = [ | ||||||
|  |   { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, | ||||||
|  |   { key: 'ama', text: 'AMA 问天', value: 'ama' }, | ||||||
|  |   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||||
|  | ]; | ||||||
|  |  | ||||||
|  | const OPEN_LINK_OPTIONS = [ | ||||||
|  |   { key: 'ama', text: 'AMA 问天', value: 'ama' }, | ||||||
|  |   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||||
|  | ]; | ||||||
|  |  | ||||||
| function renderTimestamp(timestamp) { | function renderTimestamp(timestamp) { | ||||||
|   return ( |   return ( | ||||||
|     <> |     <> | ||||||
| @@ -68,6 +79,84 @@ const TokensTable = () => { | |||||||
|   const refresh = async () => { |   const refresh = async () => { | ||||||
|     setLoading(true); |     setLoading(true); | ||||||
|     await loadTokens(activePage - 1); |     await loadTokens(activePage - 1); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   const onCopy = async (type, key) => { | ||||||
|  |     let status = localStorage.getItem('status'); | ||||||
|  |     let serverAddress = ''; | ||||||
|  |     if (status) { | ||||||
|  |       status = JSON.parse(status); | ||||||
|  |       serverAddress = status.server_address; | ||||||
|  |     } | ||||||
|  |     if (serverAddress === '') { | ||||||
|  |       serverAddress = window.location.origin; | ||||||
|  |     } | ||||||
|  |     let encodedServerAddress = encodeURIComponent(serverAddress); | ||||||
|  |     const nextLink = localStorage.getItem('chat_link'); | ||||||
|  |     let nextUrl; | ||||||
|  |    | ||||||
|  |     if (nextLink) { | ||||||
|  |       nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|  |     } else { | ||||||
|  |       nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     let url; | ||||||
|  |     switch (type) { | ||||||
|  |       case 'ama': | ||||||
|  |         url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; | ||||||
|  |         break; | ||||||
|  |       case 'opencat': | ||||||
|  |         url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; | ||||||
|  |         break; | ||||||
|  |       case 'next': | ||||||
|  |         url = nextUrl; | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         url = `sk-${key}`; | ||||||
|  |     } | ||||||
|  |     if (await copy(url)) { | ||||||
|  |       showSuccess('已复制到剪贴板!'); | ||||||
|  |     } else { | ||||||
|  |       showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); | ||||||
|  |       setSearchKeyword(url); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   const onOpenLink = async (type, key) => { | ||||||
|  |     let status = localStorage.getItem('status'); | ||||||
|  |     let serverAddress = ''; | ||||||
|  |     if (status) { | ||||||
|  |       status = JSON.parse(status); | ||||||
|  |       serverAddress = status.server_address;  | ||||||
|  |     } | ||||||
|  |     if (serverAddress === '') { | ||||||
|  |       serverAddress = window.location.origin; | ||||||
|  |     } | ||||||
|  |     let encodedServerAddress = encodeURIComponent(serverAddress); | ||||||
|  |     const chatLink = localStorage.getItem('chat_link'); | ||||||
|  |     let defaultUrl; | ||||||
|  |    | ||||||
|  |     if (chatLink) { | ||||||
|  |       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; | ||||||
|  |     } else { | ||||||
|  |       defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|  |     } | ||||||
|  |     let url; | ||||||
|  |     switch (type) { | ||||||
|  |       case 'ama': | ||||||
|  |         url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; | ||||||
|  |         break; | ||||||
|  |    | ||||||
|  |       case 'opencat': | ||||||
|  |         url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; | ||||||
|  |         break; | ||||||
|  |          | ||||||
|  |       default: | ||||||
|  |         url = defaultUrl; | ||||||
|  |     } | ||||||
|  |    | ||||||
|  |     window.open(url, '_blank'); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
| @@ -235,21 +324,51 @@ const TokensTable = () => { | |||||||
|                   <Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell> |                   <Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell> | ||||||
|                   <Table.Cell> |                   <Table.Cell> | ||||||
|                     <div> |                     <div> | ||||||
|  |                     <Button.Group color='green' size={'small'}> | ||||||
|                         <Button |                         <Button | ||||||
|                           size={'small'} |                           size={'small'} | ||||||
|                           positive |                           positive | ||||||
|                           onClick={async () => { |                           onClick={async () => { | ||||||
|                           let key = "sk-" + token.key; |                             await onCopy('', token.key); | ||||||
|                           if (await copy(key)) { |  | ||||||
|                             showSuccess('已复制到剪贴板!'); |  | ||||||
|                           } else { |  | ||||||
|                             showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); |  | ||||||
|                             setSearchKeyword(key); |  | ||||||
|                           } |  | ||||||
|                           }} |                           }} | ||||||
|                         > |                         > | ||||||
|                           复制 |                           复制 | ||||||
|                         </Button> |                         </Button> | ||||||
|  |                         <Dropdown | ||||||
|  |                           className='button icon' | ||||||
|  |                           floating | ||||||
|  |                           options={COPY_OPTIONS.map(option => ({ | ||||||
|  |                             ...option, | ||||||
|  |                             onClick: async () => { | ||||||
|  |                               await onCopy(option.value, token.key); | ||||||
|  |                             } | ||||||
|  |                           }))} | ||||||
|  |                           trigger={<></>} | ||||||
|  |                         /> | ||||||
|  |                       </Button.Group> | ||||||
|  |                       {' '} | ||||||
|  |                       <Button.Group color='blue' size={'small'}> | ||||||
|  |                         <Button | ||||||
|  |                             size={'small'} | ||||||
|  |                             positive | ||||||
|  |                             onClick={() => {      | ||||||
|  |                               onOpenLink('', token.key);        | ||||||
|  |                             }}> | ||||||
|  |                             聊天 | ||||||
|  |                           </Button> | ||||||
|  |                           <Dropdown    | ||||||
|  |                             className="button icon"        | ||||||
|  |                             floating | ||||||
|  |                             options={OPEN_LINK_OPTIONS.map(option => ({ | ||||||
|  |                               ...option, | ||||||
|  |                               onClick: async () => { | ||||||
|  |                                 await onOpenLink(option.value, token.key); | ||||||
|  |                               } | ||||||
|  |                             }))}        | ||||||
|  |                             trigger={<></>}    | ||||||
|  |                           /> | ||||||
|  |                       </Button.Group> | ||||||
|  |                       {' '} | ||||||
|                       <Popup |                       <Popup | ||||||
|                         trigger={ |                         trigger={ | ||||||
|                           <Button size='small' negative> |                           <Button size='small' negative> | ||||||
|   | |||||||
| @@ -227,7 +227,7 @@ const UsersTable = () => { | |||||||
|                       content={user.email ? user.email : '未绑定邮箱地址'} |                       content={user.email ? user.email : '未绑定邮箱地址'} | ||||||
|                       key={user.username} |                       key={user.username} | ||||||
|                       header={user.display_name ? user.display_name : user.username} |                       header={user.display_name ? user.display_name : user.username} | ||||||
|                       trigger={<span>{renderText(user.username, 10)}</span>} |                       trigger={<span>{renderText(user.username, 15)}</span>} | ||||||
|                       hoverable |                       hoverable | ||||||
|                     /> |                     /> | ||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
|   | |||||||
							
								
								
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | import { API, showError } from '../helpers'; | ||||||
|  |  | ||||||
|  | export async function getOAuthState() { | ||||||
|  |   const res = await API.get('/api/oauth/state'); | ||||||
|  |   const { success, message, data } = res.data; | ||||||
|  |   if (success) { | ||||||
|  |     return data; | ||||||
|  |   } else { | ||||||
|  |     showError(message); | ||||||
|  |     return ''; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | export async function onGitHubOAuthClicked(github_client_id) { | ||||||
|  |   const state = await getOAuthState(); | ||||||
|  |   if (!state) return; | ||||||
|  |   window.open( | ||||||
|  |     `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email` | ||||||
|  |   ); | ||||||
|  | } | ||||||
| @@ -1,14 +1,24 @@ | |||||||
| export const CHANNEL_OPTIONS = [ | export const CHANNEL_OPTIONS = [ | ||||||
|   { key: 1, text: 'OpenAI', value: 1, color: 'green' }, |   { key: 1, text: 'OpenAI', value: 1, color: 'green' }, | ||||||
|   { key: 8, text: '自定义', value: 8, color: 'pink' }, |   { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, | ||||||
|   { key: 3, text: 'Azure', value: 3, color: 'olive' }, |   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||||
|   { key: 2, text: 'API2D', value: 2, color: 'blue' }, |   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||||
|   { key: 4, text: 'CloseAI', value: 4, color: 'teal' }, |   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||||
|   { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' }, |   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||||
|   { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' }, |   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||||
|   { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' }, |   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||||
|   { key: 9, text: 'AI.LS', value: 9, color: 'yellow' }, |   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, | ||||||
|   { key: 10, text: 'AI Proxy', value: 10, color: 'purple' }, |   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|   { key: 12, text: 'API2GPT', value: 12, color: 'blue' }, |   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||||
|   { key: 13, text: 'AIGC2D', value: 13, color: 'purple' } |   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||||
|  |   { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, | ||||||
|  |   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, | ||||||
|  |   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, | ||||||
|  |   { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, | ||||||
|  |   { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, | ||||||
|  |   { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, | ||||||
|  |   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, | ||||||
|  |   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, | ||||||
|  |   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, | ||||||
|  |   { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } | ||||||
| ]; | ]; | ||||||
| @@ -1,5 +1,5 @@ | |||||||
| export const toastConstants = { | export const toastConstants = { | ||||||
|   SUCCESS_TIMEOUT: 500, |   SUCCESS_TIMEOUT: 1500, | ||||||
|   INFO_TIMEOUT: 3000, |   INFO_TIMEOUT: 3000, | ||||||
|   ERROR_TIMEOUT: 5000, |   ERROR_TIMEOUT: 5000, | ||||||
|   WARNING_TIMEOUT: 10000, |   WARNING_TIMEOUT: 10000, | ||||||
|   | |||||||
| @@ -1,6 +1,11 @@ | |||||||
| import { toast } from 'react-toastify'; | import { toast } from 'react-toastify'; | ||||||
| import { toastConstants } from '../constants'; | import { toastConstants } from '../constants'; | ||||||
|  | import React from 'react'; | ||||||
|  |  | ||||||
|  | const HTMLToastContent = ({ htmlContent }) => { | ||||||
|  |   return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />; | ||||||
|  | }; | ||||||
|  | export default HTMLToastContent; | ||||||
| export function isAdmin() { | export function isAdmin() { | ||||||
|   let user = localStorage.getItem('user'); |   let user = localStorage.getItem('user'); | ||||||
|   if (!user) return false; |   if (!user) return false; | ||||||
| @@ -107,8 +112,12 @@ export function showInfo(message) { | |||||||
|   toast.info(message, showInfoOptions); |   toast.info(message, showInfoOptions); | ||||||
| } | } | ||||||
|  |  | ||||||
| export function showNotice(message) { | export function showNotice(message, isHTML = false) { | ||||||
|  |   if (isHTML) { | ||||||
|  |     toast(<HTMLToastContent htmlContent={message} />, showNoticeOptions); | ||||||
|  |   } else { | ||||||
|     toast.info(message, showNoticeOptions); |     toast.info(message, showNoticeOptions); | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| export function openPage(url) { | export function openPage(url) { | ||||||
|   | |||||||
| @@ -46,9 +46,7 @@ const About = () => { | |||||||
|             about.startsWith('https://') ? <iframe |             about.startsWith('https://') ? <iframe | ||||||
|               src={about} |               src={about} | ||||||
|               style={{ width: '100%', height: '100vh', border: 'none' }} |               style={{ width: '100%', height: '100vh', border: 'none' }} | ||||||
|             /> : <Segment> |             /> : <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> | ||||||
|               <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> |  | ||||||
|             </Segment> |  | ||||||
|           } |           } | ||||||
|         </> |         </> | ||||||
|       } |       } | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; | import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; | ||||||
| import { useParams } from 'react-router-dom'; | import { useNavigate, useParams } from 'react-router-dom'; | ||||||
| import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | ||||||
| import { CHANNEL_OPTIONS } from '../../constants'; | import { CHANNEL_OPTIONS } from '../../constants'; | ||||||
|  |  | ||||||
| @@ -10,11 +10,30 @@ const MODEL_MAPPING_EXAMPLE = { | |||||||
|   'gpt-4-32k-0314': 'gpt-4-32k' |   'gpt-4-32k-0314': 'gpt-4-32k' | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | function type2secretPrompt(type) { | ||||||
|  |   // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') | ||||||
|  |   switch (type) { | ||||||
|  |     case 15: | ||||||
|  |       return '按照如下格式输入:APIKey|SecretKey'; | ||||||
|  |     case 18: | ||||||
|  |       return '按照如下格式输入:APPID|APISecret|APIKey'; | ||||||
|  |     case 22: | ||||||
|  |       return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; | ||||||
|  |     default: | ||||||
|  |       return '请输入渠道对应的鉴权密钥'; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| const EditChannel = () => { | const EditChannel = () => { | ||||||
|   const params = useParams(); |   const params = useParams(); | ||||||
|  |   const navigate = useNavigate(); | ||||||
|   const channelId = params.id; |   const channelId = params.id; | ||||||
|   const isEdit = channelId !== undefined; |   const isEdit = channelId !== undefined; | ||||||
|   const [loading, setLoading] = useState(isEdit); |   const [loading, setLoading] = useState(isEdit); | ||||||
|  |   const handleCancel = () => { | ||||||
|  |     navigate('/channel'); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   const originInputs = { |   const originInputs = { | ||||||
|     name: '', |     name: '', | ||||||
|     type: 1, |     type: 1, | ||||||
| @@ -27,6 +46,7 @@ const EditChannel = () => { | |||||||
|   }; |   }; | ||||||
|   const [batch, setBatch] = useState(false); |   const [batch, setBatch] = useState(false); | ||||||
|   const [inputs, setInputs] = useState(originInputs); |   const [inputs, setInputs] = useState(originInputs); | ||||||
|  |   const [originModelOptions, setOriginModelOptions] = useState([]); | ||||||
|   const [modelOptions, setModelOptions] = useState([]); |   const [modelOptions, setModelOptions] = useState([]); | ||||||
|   const [groupOptions, setGroupOptions] = useState([]); |   const [groupOptions, setGroupOptions] = useState([]); | ||||||
|   const [basicModels, setBasicModels] = useState([]); |   const [basicModels, setBasicModels] = useState([]); | ||||||
| @@ -34,6 +54,33 @@ const EditChannel = () => { | |||||||
|   const [customModel, setCustomModel] = useState(''); |   const [customModel, setCustomModel] = useState(''); | ||||||
|   const handleInputChange = (e, { name, value }) => { |   const handleInputChange = (e, { name, value }) => { | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|  |     if (name === 'type' && inputs.models.length === 0) { | ||||||
|  |       let localModels = []; | ||||||
|  |       switch (value) { | ||||||
|  |         case 14: | ||||||
|  |           localModels = ['claude-instant-1', 'claude-2']; | ||||||
|  |           break; | ||||||
|  |         case 11: | ||||||
|  |           localModels = ['PaLM-2']; | ||||||
|  |           break; | ||||||
|  |         case 15: | ||||||
|  |           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||||
|  |           break; | ||||||
|  |         case 17: | ||||||
|  |           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; | ||||||
|  |           break; | ||||||
|  |         case 16: | ||||||
|  |           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||||
|  |           break; | ||||||
|  |         case 18: | ||||||
|  |           localModels = ['SparkDesk']; | ||||||
|  |           break; | ||||||
|  |         case 19: | ||||||
|  |           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4']; | ||||||
|  |           break; | ||||||
|  |       } | ||||||
|  |       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||||
|  |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const loadChannel = async () => { |   const loadChannel = async () => { | ||||||
| @@ -44,19 +91,6 @@ const EditChannel = () => { | |||||||
|         data.models = []; |         data.models = []; | ||||||
|       } else { |       } else { | ||||||
|         data.models = data.models.split(','); |         data.models = data.models.split(','); | ||||||
|         setTimeout(() => { |  | ||||||
|           let localModelOptions = [...modelOptions]; |  | ||||||
|           data.models.forEach((model) => { |  | ||||||
|             if (!localModelOptions.find((option) => option.key === model)) { |  | ||||||
|               localModelOptions.push({ |  | ||||||
|                 key: model, |  | ||||||
|                 text: model, |  | ||||||
|                 value: model |  | ||||||
|               }); |  | ||||||
|             } |  | ||||||
|           }); |  | ||||||
|           setModelOptions(localModelOptions); |  | ||||||
|         }, 1000); |  | ||||||
|       } |       } | ||||||
|       if (data.group === '') { |       if (data.group === '') { | ||||||
|         data.groups = []; |         data.groups = []; | ||||||
| @@ -76,13 +110,16 @@ const EditChannel = () => { | |||||||
|   const fetchModels = async () => { |   const fetchModels = async () => { | ||||||
|     try { |     try { | ||||||
|       let res = await API.get(`/api/channel/models`); |       let res = await API.get(`/api/channel/models`); | ||||||
|       setModelOptions(res.data.data.map((model) => ({ |       let localModelOptions = res.data.data.map((model) => ({ | ||||||
|         key: model.id, |         key: model.id, | ||||||
|         text: model.id, |         text: model.id, | ||||||
|         value: model.id |         value: model.id | ||||||
|       }))); |       })); | ||||||
|  |       setOriginModelOptions(localModelOptions); | ||||||
|       setFullModels(res.data.data.map((model) => model.id)); |       setFullModels(res.data.data.map((model) => model.id)); | ||||||
|       setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id)); |       setBasicModels(res.data.data.filter((model) => { | ||||||
|  |         return model.id.startsWith('gpt-3') || model.id.startsWith('text-'); | ||||||
|  |       }).map((model) => model.id)); | ||||||
|     } catch (error) { |     } catch (error) { | ||||||
|       showError(error.message); |       showError(error.message); | ||||||
|     } |     } | ||||||
| @@ -101,6 +138,20 @@ const EditChannel = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   useEffect(() => { | ||||||
|  |     let localModelOptions = [...originModelOptions]; | ||||||
|  |     inputs.models.forEach((model) => { | ||||||
|  |       if (!localModelOptions.find((option) => option.key === model)) { | ||||||
|  |         localModelOptions.push({ | ||||||
|  |           key: model, | ||||||
|  |           text: model, | ||||||
|  |           value: model | ||||||
|  |         }); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |     setModelOptions(localModelOptions); | ||||||
|  |   }, [originModelOptions, inputs.models]); | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     if (isEdit) { |     if (isEdit) { | ||||||
|       loadChannel().then(); |       loadChannel().then(); | ||||||
| @@ -123,11 +174,14 @@ const EditChannel = () => { | |||||||
|       return; |       return; | ||||||
|     } |     } | ||||||
|     let localInputs = inputs; |     let localInputs = inputs; | ||||||
|     if (localInputs.base_url.endsWith('/')) { |     if (localInputs.base_url && localInputs.base_url.endsWith('/')) { | ||||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); |       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||||
|     } |     } | ||||||
|     if (localInputs.type === 3 && localInputs.other === '') { |     if (localInputs.type === 3 && localInputs.other === '') { | ||||||
|       localInputs.other = '2023-03-15-preview'; |       localInputs.other = '2023-06-01-preview'; | ||||||
|  |     } | ||||||
|  |     if (localInputs.type === 18 && localInputs.other === '') { | ||||||
|  |       localInputs.other = 'v2.1'; | ||||||
|     } |     } | ||||||
|     let res; |     let res; | ||||||
|     localInputs.models = localInputs.models.join(','); |     localInputs.models = localInputs.models.join(','); | ||||||
| @@ -150,6 +204,24 @@ const EditChannel = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const addCustomModel = () => { | ||||||
|  |     if (customModel.trim() === '') return; | ||||||
|  |     if (inputs.models.includes(customModel)) return; | ||||||
|  |     let localModels = [...inputs.models]; | ||||||
|  |     localModels.push(customModel); | ||||||
|  |     let localModelOptions = []; | ||||||
|  |     localModelOptions.push({ | ||||||
|  |       key: customModel, | ||||||
|  |       text: customModel, | ||||||
|  |       value: customModel | ||||||
|  |     }); | ||||||
|  |     setModelOptions(modelOptions => { | ||||||
|  |       return [...modelOptions, ...localModelOptions]; | ||||||
|  |     }); | ||||||
|  |     setCustomModel(''); | ||||||
|  |     handleInputChange(null, { name: 'models', value: localModels }); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <> |     <> | ||||||
|       <Segment loading={loading}> |       <Segment loading={loading}> | ||||||
| @@ -187,7 +259,7 @@ const EditChannel = () => { | |||||||
|                   <Form.Input |                   <Form.Input | ||||||
|                     label='默认 API 版本' |                     label='默认 API 版本' | ||||||
|                     name='other' |                     name='other' | ||||||
|                     placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'} |                     placeholder={'请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||||
|                     onChange={handleInputChange} |                     onChange={handleInputChange} | ||||||
|                     value={inputs.other} |                     value={inputs.other} | ||||||
|                     autoComplete='new-password' |                     autoComplete='new-password' | ||||||
| @@ -210,26 +282,12 @@ const EditChannel = () => { | |||||||
|               </Form.Field> |               </Form.Field> | ||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|           { |  | ||||||
|             inputs.type !== 3 && inputs.type !== 8 && ( |  | ||||||
|               <Form.Field> |  | ||||||
|                 <Form.Input |  | ||||||
|                   label='镜像' |  | ||||||
|                   name='base_url' |  | ||||||
|                   placeholder={'此项可选,输入镜像站地址,格式为:https://domain.com'} |  | ||||||
|                   onChange={handleInputChange} |  | ||||||
|                   value={inputs.base_url} |  | ||||||
|                   autoComplete='new-password' |  | ||||||
|                 /> |  | ||||||
|               </Form.Field> |  | ||||||
|             ) |  | ||||||
|           } |  | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Input |             <Form.Input | ||||||
|               label='名称' |               label='名称' | ||||||
|               required |               required | ||||||
|               name='name' |               name='name' | ||||||
|               placeholder={'请输入名称'} |               placeholder={'请为渠道命名'} | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={inputs.name} |               value={inputs.name} | ||||||
|               autoComplete='new-password' |               autoComplete='new-password' | ||||||
| @@ -238,7 +296,7 @@ const EditChannel = () => { | |||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Dropdown |             <Form.Dropdown | ||||||
|               label='分组' |               label='分组' | ||||||
|               placeholder={'请选择分组'} |               placeholder={'请选择可以使用该渠道的分组'} | ||||||
|               name='groups' |               name='groups' | ||||||
|               required |               required | ||||||
|               fluid |               fluid | ||||||
| @@ -252,10 +310,38 @@ const EditChannel = () => { | |||||||
|               options={groupOptions} |               options={groupOptions} | ||||||
|             /> |             /> | ||||||
|           </Form.Field> |           </Form.Field> | ||||||
|  |           { | ||||||
|  |             inputs.type === 18 && ( | ||||||
|  |               <Form.Field> | ||||||
|  |                 <Form.Input | ||||||
|  |                   label='模型版本' | ||||||
|  |                   name='other' | ||||||
|  |                   placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'} | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.other} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |             ) | ||||||
|  |           } | ||||||
|  |           { | ||||||
|  |             inputs.type === 21 && ( | ||||||
|  |               <Form.Field> | ||||||
|  |                 <Form.Input | ||||||
|  |                   label='知识库 ID' | ||||||
|  |                   name='other' | ||||||
|  |                   placeholder={'请输入知识库 ID,例如:123456'} | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.other} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |             ) | ||||||
|  |           } | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Dropdown |             <Form.Dropdown | ||||||
|               label='模型' |               label='模型' | ||||||
|               placeholder={'请选择该通道所支持的模型'} |               placeholder={'请选择该渠道所支持的模型'} | ||||||
|               name='models' |               name='models' | ||||||
|               required |               required | ||||||
|               fluid |               fluid | ||||||
| @@ -279,30 +365,25 @@ const EditChannel = () => { | |||||||
|             }}>清除所有模型</Button> |             }}>清除所有模型</Button> | ||||||
|             <Input |             <Input | ||||||
|               action={ |               action={ | ||||||
|                 <Button type={'button'} onClick={()=>{ |                 <Button type={'button'} onClick={addCustomModel}>填入</Button> | ||||||
|                   let localModels = [...inputs.models]; |  | ||||||
|                   localModels.push(customModel); |  | ||||||
|                   let localModelOptions = [...modelOptions]; |  | ||||||
|                   localModelOptions.push({ |  | ||||||
|                     key: customModel, |  | ||||||
|                     text: customModel, |  | ||||||
|                     value: customModel, |  | ||||||
|                   }); |  | ||||||
|                   setModelOptions(localModelOptions); |  | ||||||
|                   handleInputChange(null, { name: 'models', value: localModels }); |  | ||||||
|                 }}>填入</Button> |  | ||||||
|               } |               } | ||||||
|               placeholder='输入自定义模型名称' |               placeholder='输入自定义模型名称' | ||||||
|               value={customModel} |               value={customModel} | ||||||
|               onChange={(e, { value }) => { |               onChange={(e, { value }) => { | ||||||
|                 setCustomModel(value); |                 setCustomModel(value); | ||||||
|               }} |               }} | ||||||
|  |               onKeyDown={(e) => { | ||||||
|  |                 if (e.key === 'Enter') { | ||||||
|  |                   addCustomModel(); | ||||||
|  |                   e.preventDefault(); | ||||||
|  |                 } | ||||||
|  |               }} | ||||||
|             /> |             /> | ||||||
|           </div> |           </div> | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.TextArea |             <Form.TextArea | ||||||
|               label='模型映射' |               label='模型重定向' | ||||||
|               placeholder={`此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} |               placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} | ||||||
|               name='model_mapping' |               name='model_mapping' | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={inputs.model_mapping} |               value={inputs.model_mapping} | ||||||
| @@ -327,7 +408,7 @@ const EditChannel = () => { | |||||||
|                 label='密钥' |                 label='密钥' | ||||||
|                 name='key' |                 name='key' | ||||||
|                 required |                 required | ||||||
|                 placeholder={'请输入密钥'} |                 placeholder={type2secretPrompt(inputs.type)} | ||||||
|                 onChange={handleInputChange} |                 onChange={handleInputChange} | ||||||
|                 value={inputs.key} |                 value={inputs.key} | ||||||
|                 autoComplete='new-password' |                 autoComplete='new-password' | ||||||
| @@ -344,7 +425,36 @@ const EditChannel = () => { | |||||||
|               /> |               /> | ||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|           <Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button> |           { | ||||||
|  |             inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( | ||||||
|  |               <Form.Field> | ||||||
|  |                 <Form.Input | ||||||
|  |                   label='代理' | ||||||
|  |                   name='base_url' | ||||||
|  |                   placeholder={'此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com'} | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.base_url} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |             ) | ||||||
|  |           } | ||||||
|  |           { | ||||||
|  |             inputs.type === 22 && ( | ||||||
|  |               <Form.Field> | ||||||
|  |                 <Form.Input | ||||||
|  |                   label='私有部署地址' | ||||||
|  |                   name='base_url' | ||||||
|  |                   placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'} | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.base_url} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |             ) | ||||||
|  |           } | ||||||
|  |           <Button onClick={handleCancel}>取消</Button> | ||||||
|  |           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button> | ||||||
|         </Form> |         </Form> | ||||||
|       </Segment> |       </Segment> | ||||||
|     </> |     </> | ||||||
|   | |||||||
| @@ -15,7 +15,8 @@ const Home = () => { | |||||||
|     if (success) { |     if (success) { | ||||||
|       let oldNotice = localStorage.getItem('notice'); |       let oldNotice = localStorage.getItem('notice'); | ||||||
|         if (data !== oldNotice && data !== '') { |         if (data !== oldNotice && data !== '') { | ||||||
|         showNotice(data); |             const htmlNotice = marked(data); | ||||||
|  |             showNotice(htmlNotice, true); | ||||||
|             localStorage.setItem('notice', data); |             localStorage.setItem('notice', data); | ||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
| @@ -64,7 +65,7 @@ const Home = () => { | |||||||
|                     <Card.Meta>系统信息总览</Card.Meta> |                     <Card.Meta>系统信息总览</Card.Meta> | ||||||
|                     <Card.Description> |                     <Card.Description> | ||||||
|                       <p>名称:{statusState?.status?.system_name}</p> |                       <p>名称:{statusState?.status?.system_name}</p> | ||||||
|                       <p>版本:{statusState?.status?.version}</p> |                       <p>版本:{statusState?.status?.version ? statusState?.status?.version : "unknown"}</p> | ||||||
|                       <p> |                       <p> | ||||||
|                         源码: |                         源码: | ||||||
|                         <a |                         <a | ||||||
|   | |||||||
| @@ -1,11 +1,12 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Header, Segment } from 'semantic-ui-react'; | import { Button, Form, Header, Segment } from 'semantic-ui-react'; | ||||||
| import { useParams } from 'react-router-dom'; | import { useParams, useNavigate } from 'react-router-dom'; | ||||||
| import { API, downloadTextAsFile, showError, showSuccess } from '../../helpers'; | import { API, downloadTextAsFile, showError, showSuccess } from '../../helpers'; | ||||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | ||||||
|  |  | ||||||
| const EditRedemption = () => { | const EditRedemption = () => { | ||||||
|   const params = useParams(); |   const params = useParams(); | ||||||
|  |   const navigate = useNavigate(); | ||||||
|   const redemptionId = params.id; |   const redemptionId = params.id; | ||||||
|   const isEdit = redemptionId !== undefined; |   const isEdit = redemptionId !== undefined; | ||||||
|   const [loading, setLoading] = useState(isEdit); |   const [loading, setLoading] = useState(isEdit); | ||||||
| @@ -17,6 +18,10 @@ const EditRedemption = () => { | |||||||
|   const [inputs, setInputs] = useState(originInputs); |   const [inputs, setInputs] = useState(originInputs); | ||||||
|   const { name, quota, count } = inputs; |   const { name, quota, count } = inputs; | ||||||
|  |  | ||||||
|  |   const handleCancel = () => { | ||||||
|  |     navigate('/redemption'); | ||||||
|  |   }; | ||||||
|  |    | ||||||
|   const handleInputChange = (e, { name, value }) => { |   const handleInputChange = (e, { name, value }) => { | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|   }; |   }; | ||||||
| @@ -113,6 +118,7 @@ const EditRedemption = () => { | |||||||
|             </> |             </> | ||||||
|           } |           } | ||||||
|           <Button positive onClick={submit}>提交</Button> |           <Button positive onClick={submit}>提交</Button> | ||||||
|  |           <Button onClick={handleCancel}>取消</Button> | ||||||
|         </Form> |         </Form> | ||||||
|       </Segment> |       </Segment> | ||||||
|     </> |     </> | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | ||||||
| import { useParams } from 'react-router-dom'; | import { useParams, useNavigate } from 'react-router-dom'; | ||||||
| import { API, showError, showSuccess, timestamp2string } from '../../helpers'; | import { API, showError, showSuccess, timestamp2string } from '../../helpers'; | ||||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | ||||||
|  |  | ||||||
| @@ -17,11 +17,13 @@ const EditToken = () => { | |||||||
|   }; |   }; | ||||||
|   const [inputs, setInputs] = useState(originInputs); |   const [inputs, setInputs] = useState(originInputs); | ||||||
|   const { name, remain_quota, expired_time, unlimited_quota } = inputs; |   const { name, remain_quota, expired_time, unlimited_quota } = inputs; | ||||||
|  |   const navigate = useNavigate(); | ||||||
|   const handleInputChange = (e, { name, value }) => { |   const handleInputChange = (e, { name, value }) => { | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|   }; |   }; | ||||||
|  |   const handleCancel = () => { | ||||||
|  |     navigate("/token"); | ||||||
|  |   } | ||||||
|   const setExpiredTime = (month, day, hour, minute) => { |   const setExpiredTime = (month, day, hour, minute) => { | ||||||
|     let now = new Date(); |     let now = new Date(); | ||||||
|     let timestamp = now.getTime() / 1000; |     let timestamp = now.getTime() / 1000; | ||||||
| @@ -83,7 +85,7 @@ const EditToken = () => { | |||||||
|       if (isEdit) { |       if (isEdit) { | ||||||
|         showSuccess('令牌更新成功!'); |         showSuccess('令牌更新成功!'); | ||||||
|       } else { |       } else { | ||||||
|         showSuccess('令牌创建成功!'); |         showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!'); | ||||||
|         setInputs(originInputs); |         setInputs(originInputs); | ||||||
|       } |       } | ||||||
|     } else { |     } else { | ||||||
| @@ -150,8 +152,9 @@ const EditToken = () => { | |||||||
|           </Form.Field> |           </Form.Field> | ||||||
|           <Button type={'button'} onClick={() => { |           <Button type={'button'} onClick={() => { | ||||||
|             setUnlimitedQuota(); |             setUnlimitedQuota(); | ||||||
|           }}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button> |           }}>{unlimited_quota ? '取消无限额度' : '设为无限额度'}</Button> | ||||||
|           <Button positive onClick={submit}>提交</Button> |           <Button floated='right' positive onClick={submit}>提交</Button> | ||||||
|  |           <Button floated='right' onClick={handleCancel}>取消</Button> | ||||||
|         </Form> |         </Form> | ||||||
|       </Segment> |       </Segment> | ||||||
|     </> |     </> | ||||||
|   | |||||||
| @@ -7,12 +7,15 @@ const TopUp = () => { | |||||||
|   const [redemptionCode, setRedemptionCode] = useState(''); |   const [redemptionCode, setRedemptionCode] = useState(''); | ||||||
|   const [topUpLink, setTopUpLink] = useState(''); |   const [topUpLink, setTopUpLink] = useState(''); | ||||||
|   const [userQuota, setUserQuota] = useState(0); |   const [userQuota, setUserQuota] = useState(0); | ||||||
|  |   const [isSubmitting, setIsSubmitting] = useState(false); | ||||||
|  |  | ||||||
|   const topUp = async () => { |   const topUp = async () => { | ||||||
|     if (redemptionCode === '') { |     if (redemptionCode === '') { | ||||||
|       showInfo('请输入充值码!') |       showInfo('请输入充值码!') | ||||||
|       return; |       return; | ||||||
|     } |     } | ||||||
|  |     setIsSubmitting(true); | ||||||
|  |     try { | ||||||
|       const res = await API.post('/api/user/topup', { |       const res = await API.post('/api/user/topup', { | ||||||
|         key: redemptionCode |         key: redemptionCode | ||||||
|       }); |       }); | ||||||
| @@ -26,6 +29,11 @@ const TopUp = () => { | |||||||
|       } else { |       } else { | ||||||
|         showError(message); |         showError(message); | ||||||
|       } |       } | ||||||
|  |     } catch (err) { | ||||||
|  |       showError('请求失败'); | ||||||
|  |     } finally { | ||||||
|  |       setIsSubmitting(false);  | ||||||
|  |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const openTopUpLink = () => { |   const openTopUpLink = () => { | ||||||
| @@ -74,8 +82,8 @@ const TopUp = () => { | |||||||
|             <Button color='green' onClick={openTopUpLink}> |             <Button color='green' onClick={openTopUpLink}> | ||||||
|               获取兑换码 |               获取兑换码 | ||||||
|             </Button> |             </Button> | ||||||
|             <Button color='yellow' onClick={topUp}> |             <Button color='yellow' onClick={topUp} disabled={isSubmitting}> | ||||||
|               充值 |                 {isSubmitting ? '兑换中...' : '兑换'} | ||||||
|             </Button> |             </Button> | ||||||
|           </Form> |           </Form> | ||||||
|         </Grid.Column> |         </Grid.Column> | ||||||
| @@ -92,5 +100,4 @@ const TopUp = () => { | |||||||
|   ); |   ); | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  |  | ||||||
| export default TopUp; | export default TopUp; | ||||||
| @@ -1,6 +1,6 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Header, Segment } from 'semantic-ui-react'; | import { Button, Form, Header, Segment } from 'semantic-ui-react'; | ||||||
| import { useParams } from 'react-router-dom'; | import { useParams, useNavigate } from 'react-router-dom'; | ||||||
| import { API, showError, showSuccess } from '../../helpers'; | import { API, showError, showSuccess } from '../../helpers'; | ||||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | ||||||
|  |  | ||||||
| @@ -36,7 +36,10 @@ const EditUser = () => { | |||||||
|       showError(error.message); |       showError(error.message); | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |   const navigate = useNavigate(); | ||||||
|  |   const handleCancel = () => { | ||||||
|  |     navigate("/setting"); | ||||||
|  |   } | ||||||
|   const loadUser = async () => { |   const loadUser = async () => { | ||||||
|     let res = undefined; |     let res = undefined; | ||||||
|     if (userId) { |     if (userId) { | ||||||
| @@ -99,7 +102,7 @@ const EditUser = () => { | |||||||
|               label='密码' |               label='密码' | ||||||
|               name='password' |               name='password' | ||||||
|               type={'password'} |               type={'password'} | ||||||
|               placeholder={'请输入新的密码'} |               placeholder={'请输入新的密码,最短 8 位'} | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={password} |               value={password} | ||||||
|               autoComplete='new-password' |               autoComplete='new-password' | ||||||
| @@ -176,6 +179,7 @@ const EditUser = () => { | |||||||
|               readOnly |               readOnly | ||||||
|             /> |             /> | ||||||
|           </Form.Field> |           </Form.Field> | ||||||
|  |           <Button onClick={handleCancel}>取消</Button> | ||||||
|           <Button positive onClick={submit}>提交</Button> |           <Button positive onClick={submit}>提交</Button> | ||||||
|         </Form> |         </Form> | ||||||
|       </Segment> |       </Segment> | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user