mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 19:03:43 +08:00 
			
		
		
		
	Compare commits
	
		
			160 Commits
		
	
	
		
			v0.6.0-alp
			...
			v0.6.5
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 7bf61f9165 | ||
|  | a10232f43a | ||
|  | af543ab8ec | ||
|  | e086da05b1 | ||
|  | 3af4649b52 | ||
|  | 52c32c0b4a | ||
|  | 3fe2863ff7 | ||
|  | acf8cb6248 | ||
|  | 572fc9ffb8 | ||
|  | 569c04acb0 | ||
|  | 961b4108e6 | ||
|  | 0b8ccb94eb | ||
|  | f586ae0ad8 | ||
|  | 24ed170e7b | ||
|  | f70506eac1 | ||
|  | 8f4d78e24d | ||
|  | cd2707692f | ||
|  | 2ab7d25a80 | ||
|  | f9d914873f | ||
|  | 880e12c855 | ||
|  | 0cb224e62e | ||
|  | a44fb5d482 | ||
|  | eec41849ec | ||
|  | d4347e7a35 | ||
|  | b50b43eb65 | ||
|  | 348adc2b02 | ||
|  | dcf24b98dc | ||
|  | af679e04f4 | ||
|  | 93cbca6a9f | ||
|  | 840ef80d94 | ||
|  | 9a2662af0d | ||
|  | 77f9e75654 | ||
|  | 5b41f57423 | ||
|  | 0bb7db0b44 | ||
|  | 4d61b9937b | ||
|  | 68605800af | ||
|  | c49778c254 | ||
|  | f02c7138ea | ||
|  | ca3228855a | ||
|  | f8cc63f00b | ||
|  | 0a37aa4cbd | ||
|  | 054b00b725 | ||
|  | 76569bb0b6 | ||
|  | 1994256bac | ||
|  | 1f80b0a39f | ||
|  | f73f2e51df | ||
|  | 6f036bd0c9 | ||
|  | fb90747c23 | ||
|  | ed70881a58 | ||
|  | 8b9fa3d6e4 | ||
|  | 8b9813d63b | ||
|  | dc7aaf2de5 | ||
|  | 065da8ef8c | ||
|  | e3cfb1fa52 | ||
|  | f89ae5ad58 | ||
|  | 06a3fc5421 | ||
|  | a9c464ec5a | ||
|  | 3f3c13c98c | ||
|  | 2ba28c72cb | ||
|  | 5e81e19bc8 | ||
|  | 96d7a99312 | ||
|  | 24be9de098 | ||
|  | 5b349efff9 | ||
|  | f76c46d648 | ||
|  | cdfdeea3b4 | ||
|  | 56ddbb842a | ||
|  | 99f81a267c | ||
|  | c243cd5535 | ||
|  | e96b173abe | ||
|  | 4ae311e964 | ||
|  | b14cb748d8 | ||
|  | ade19ba4a2 | ||
|  | 4d86d021c4 | ||
|  | 7a44adb5a7 | ||
|  | 9821bc7281 | ||
|  | 08831881f1 | ||
|  | 0eb2272bb7 | ||
|  | 704ec1a827 | ||
|  | 1d7470d6ad | ||
|  | 1185303346 | ||
|  | c212fcf8d7 | ||
|  | c285e000cc | ||
|  | d25ed4c009 | ||
|  | 7400885fbb | ||
|  | 11af81eb39 | ||
|  | 205aba694f | ||
|  | 8dac3afebc | ||
|  | a07791bf93 | ||
|  | 4bb662c0e4 | ||
|  | 4998d58319 | ||
|  | 190203cf8f | ||
|  | 6325c8e0b4 | ||
|  | b204f6d82b | ||
|  | 752639560f | ||
|  | 996f4d99dd | ||
|  | ebfee3b46c | ||
|  | 3e2e805d61 | ||
|  | 3edf7247c4 | ||
|  | 0926b6206b | ||
|  | 7cd57f3125 | ||
|  | 66efabd5ae | ||
|  | 8ede66a896 | ||
|  | b169173860 | ||
|  | f33555ae78 | ||
|  | c28ec10795 | ||
|  | e3767cbb07 | ||
|  | be9eb59fbb | ||
|  | 89e111ac69 | ||
|  | 2dcef85285 | ||
|  | 79d0cd378a | ||
|  | e99150bdb9 | ||
|  | a72e5fcc9e | ||
|  | 0710f8cd66 | ||
|  | 49cad7d4a5 | ||
|  | a90161cf00 | ||
|  | a45fc7d736 | ||
|  | 45940dcb12 | ||
|  | 969042b001 | ||
|  | 7e7369dbc4 | ||
|  | e54e647170 | ||
|  | 358920c858 | ||
|  | 1ea598c773 | ||
|  | 796be42487 | ||
|  | 5b50eb94e5 | ||
|  | 71c61365eb | ||
|  | b09f979b80 | ||
|  | 12440874b0 | ||
|  | 6ebc99460e | ||
|  | 27ad8bfb98 | ||
|  | 8388aa537f | ||
|  | 2346bf70af | ||
|  | f05b403ca5 | ||
|  | b33616df44 | ||
|  | cf16f44970 | ||
|  | bf2e26a48f | ||
|  | 4fb22ad4ce | ||
|  | 95cfb8e8c9 | ||
|  | c6ace985c2 | ||
|  | 10a926b8f3 | ||
|  | 2df877a352 | ||
|  | 9d8967f7d3 | ||
|  | b35f3523d3 | ||
|  | 82e916b5ff | ||
|  | de18d6fe16 | ||
|  | 1d0b7fb5ae | ||
|  | f9490bb72e | ||
|  | 76467285e8 | ||
|  | df1fd9aa81 | ||
|  | 614c2e0442 | ||
|  | eac6a0b9aa | ||
|  | b747cdbc6f | ||
|  | 6b27d6659a | ||
|  | dc5b781191 | ||
|  | c880b4a9a3 | ||
|  | 565ea58e68 | ||
|  | f141a37a9e | ||
|  | 5b78886ad3 | ||
|  | 87c7c4f0e6 | ||
|  | 4c4a873890 | ||
|  | 0664bdfda1 | 
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,13 @@ jobs: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi       | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|   | ||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,13 @@ jobs: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi         | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|   | ||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -21,6 +21,13 @@ jobs: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|   | ||||
							
								
								
									
										10
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,10 +20,16 @@ jobs: | ||||
|         uses: actions/checkout@v3 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|       - name: Build Frontend (theme default) | ||||
|       - name: Build Frontend | ||||
|         env: | ||||
|           CI: "" | ||||
|         run: | | ||||
| @@ -38,7 +44,7 @@ jobs: | ||||
|       - name: Build Backend (amd64) | ||||
|         run: | | ||||
|           go mod download | ||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api | ||||
|           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api | ||||
|  | ||||
|       - name: Build Backend (arm64) | ||||
|         run: | | ||||
|   | ||||
							
								
								
									
										10
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,10 +20,16 @@ jobs: | ||||
|         uses: actions/checkout@v3 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|       - name: Build Frontend (theme default) | ||||
|       - name: Build Frontend | ||||
|         env: | ||||
|           CI: "" | ||||
|         run: | | ||||
| @@ -38,7 +44,7 @@ jobs: | ||||
|       - name: Build Backend | ||||
|         run: | | ||||
|           go mod download | ||||
|           go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos | ||||
|           go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos | ||||
|       - name: Release | ||||
|         uses: softprops/action-gh-release@v1 | ||||
|         if: startsWith(github.ref, 'refs/tags/') | ||||
|   | ||||
							
								
								
									
										10
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -23,10 +23,16 @@ jobs: | ||||
|         uses: actions/checkout@v3 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|       - name: Build Frontend (theme default) | ||||
|       - name: Build Frontend | ||||
|         env: | ||||
|           CI: "" | ||||
|         run: | | ||||
| @@ -41,7 +47,7 @@ jobs: | ||||
|       - name: Build Backend | ||||
|         run: | | ||||
|           go mod download | ||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe | ||||
|           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe | ||||
|       - name: Release | ||||
|         uses: softprops/action-gh-release@v1 | ||||
|         if: startsWith(github.ref, 'refs/tags/') | ||||
|   | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -8,3 +8,4 @@ build | ||||
| logs | ||||
| data | ||||
| /web/node_modules | ||||
| cmd.md | ||||
| @@ -12,6 +12,10 @@ WORKDIR /web/berry | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| WORKDIR /web/air | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| FROM golang AS builder2 | ||||
|  | ||||
| ENV GO111MODULE=on \ | ||||
|   | ||||
							
								
								
									
										14
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -241,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the | ||||
|     + Example: `SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | ||||
|     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
| 4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||
| 4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL. | ||||
|     + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` | ||||
| 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
| 6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
|     + Example: `SYNC_FREQUENCY=60` | ||||
| 6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
| 7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
|     + Example: `NODE_TYPE=slave` | ||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
| 8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
| 9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||
| 9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
| 10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
|     + Example: `POLLING_INTERVAL=5` | ||||
|  | ||||
| ### Command Line Parameters | ||||
|   | ||||
							
								
								
									
										13
									
								
								README.ja.md
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								README.ja.md
									
									
									
									
									
								
							| @@ -242,17 +242,18 @@ graph LR | ||||
|     + 例: `SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | ||||
|     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
| 4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | ||||
| 4. `LOG_SQL_DSN`: を設定すると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください。 | ||||
| 5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | ||||
|     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | ||||
| 6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | ||||
|     + 例: `SYNC_FREQUENCY=60` | ||||
| 6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | ||||
| 7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | ||||
|     + 例: `NODE_TYPE=slave` | ||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | ||||
| 8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | ||||
|     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | ||||
| 9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | ||||
|     + 例: `CHANNEL_TEST_FREQUENCY=1440` | ||||
| 9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | ||||
| 10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | ||||
|     + 例: `POLLING_INTERVAL=5` | ||||
|  | ||||
| ### コマンドラインパラメータ | ||||
|   | ||||
							
								
								
									
										60
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								README.md
									
									
									
									
									
								
							| @@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||
|    + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) | ||||
|    + [x] [Mistral 系列模型](https://mistral.ai/) | ||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
| @@ -74,15 +75,20 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [360 智脑](https://ai.360.cn) | ||||
|    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||
|    + [x] [Moonshot AI](https://platform.moonshot.cn/) | ||||
|    + [x] [百川大模型](https://platform.baichuan-ai.com) | ||||
|    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) | ||||
|    + [ ] [MINIMAX](https://api.minimax.chat/) (WIP) | ||||
|    + [x] [MINIMAX](https://api.minimax.chat/) | ||||
|    + [x] [Groq](https://wow.groq.com/) | ||||
|    + [x] [Ollama](https://github.com/ollama/ollama) | ||||
|    + [x] [零一万物](https://platform.lingyiwanwu.com/) | ||||
|    + [x] [阶跃星辰](https://platform.stepfun.com/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||
| 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||
| 6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 | ||||
| 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||
| 8. 支持**通道管理**,批量创建通道。 | ||||
| 8. 支持**渠道管理**,批量创建渠道。 | ||||
| 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||
| 10. 支持渠道**设置模型列表**。 | ||||
| 11. 支持**查看额度明细**。 | ||||
| @@ -96,13 +102,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 19. 支持丰富的**自定义**设置, | ||||
|     1. 支持自定义系统名称,logo 以及页脚。 | ||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||
| 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | ||||
| 20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。 | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + 支持使用飞书进行授权登录。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||
| 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||
|  | ||||
| ## 部署 | ||||
| ### 基于 Docker 进行部署 | ||||
| @@ -343,35 +351,41 @@ graph LR | ||||
|      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 | ||||
|        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 | ||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||
| 4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 | ||||
| 5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
| 6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||
| 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||
| 7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||
|    + 例子:`SYNC_FREQUENCY=60` | ||||
| 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||
| 8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||
|    + 例子:`NODE_TYPE=slave` | ||||
| 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||
| 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
| 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||
| 11. 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|     + 例子:`POLLING_INTERVAL=5` | ||||
| 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
| 13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||
| 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
| 14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||
| 13. 请求频率限制: | ||||
| 15. 请求频率限制: | ||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||
| 14. 编码器缓存设置: | ||||
| 16. 编码器缓存设置: | ||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
| 17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 18. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 19. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 20. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 | ||||
| 21. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
| 22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||
| 24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
| 25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
| @@ -410,7 +424,7 @@ https://openai.justsong.cn | ||||
|    + 检查你的接口地址和 API Key 有没有填对。 | ||||
|    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||
|    + 上游通道 429 了。 | ||||
|    + 上游渠道 429 了。 | ||||
| 7. 升级之后我的数据会丢失吗? | ||||
|    + 如果使用 MySQL,不会。 | ||||
|    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||
| @@ -418,8 +432,8 @@ https://openai.justsong.cn | ||||
|    + 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||
|    + 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||
| 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? | ||||
|    + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 | ||||
|    + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 | ||||
|    + 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。 | ||||
|    + 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。 | ||||
|  | ||||
| ## 相关项目 | ||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||
|   | ||||
							
								
								
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| package blacklist | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| var blackList sync.Map | ||||
|  | ||||
| func init() { | ||||
| 	blackList = sync.Map{} | ||||
| } | ||||
|  | ||||
| func userId2Key(id int) string { | ||||
| 	return fmt.Sprintf("userid_%d", id) | ||||
| } | ||||
|  | ||||
| func BanUser(id int) { | ||||
| 	blackList.Store(userId2Key(id), true) | ||||
| } | ||||
|  | ||||
| func UnbanUser(id int) { | ||||
| 	blackList.Delete(userId2Key(id)) | ||||
| } | ||||
|  | ||||
| func IsUserBanned(id int) bool { | ||||
| 	_, ok := blackList.Load(userId2Key(id)) | ||||
| 	return ok | ||||
| } | ||||
| @@ -1,7 +1,7 @@ | ||||
| package config | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| @@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{ | ||||
| } | ||||
|  | ||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||
| var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" | ||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||
|  | ||||
| var LogConsumeEnabled = true | ||||
| @@ -65,21 +66,27 @@ var SMTPToken = "" | ||||
| var GitHubClientId = "" | ||||
| var GitHubClientSecret = "" | ||||
|  | ||||
| var LarkClientId = "" | ||||
| var LarkClientSecret = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
|  | ||||
| var MessagePusherAddress = "" | ||||
| var MessagePusherToken = "" | ||||
|  | ||||
| var TurnstileSiteKey = "" | ||||
| var TurnstileSecretKey = "" | ||||
|  | ||||
| var QuotaForNewUser = 0 | ||||
| var QuotaForInviter = 0 | ||||
| var QuotaForInvitee = 0 | ||||
| var QuotaForNewUser int64 = 0 | ||||
| var QuotaForInviter int64 = 0 | ||||
| var QuotaForInvitee int64 = 0 | ||||
| var ChannelDisableThreshold = 5.0 | ||||
| var AutomaticDisableChannelEnabled = false | ||||
| var AutomaticEnableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var QuotaRemindThreshold int64 = 1000 | ||||
| var PreConsumedQuota int64 = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| var RetryTimes = 0 | ||||
|  | ||||
| @@ -90,28 +97,29 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | ||||
|  | ||||
| var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second | ||||
| var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second | ||||
|  | ||||
| var BatchUpdateEnabled = false | ||||
| var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) | ||||
| var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) | ||||
|  | ||||
| var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second | ||||
| var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second | ||||
|  | ||||
| var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | ||||
| var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | ||||
|  | ||||
| var Theme = helper.GetOrDefaultEnvString("THEME", "default") | ||||
| var Theme = env.String("THEME", "default") | ||||
| var ValidThemes = map[string]bool{ | ||||
| 	"default": true, | ||||
| 	"berry":   true, | ||||
| 	"air":     true, | ||||
| } | ||||
|  | ||||
| // All duration's unit is seconds | ||||
| // Shouldn't larger then RateLimitKeyExpirationDuration | ||||
| var ( | ||||
| 	GlobalApiRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) | ||||
| 	GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 180) | ||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	GlobalWebRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) | ||||
| 	GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) | ||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	UploadRateLimitNum            = 10 | ||||
| @@ -125,3 +133,13 @@ var ( | ||||
| ) | ||||
|  | ||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||
|  | ||||
| var EnableMetric = env.Bool("ENABLE_METRIC", false) | ||||
| var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) | ||||
| var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) | ||||
| var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) | ||||
| var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) | ||||
|  | ||||
| var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") | ||||
|  | ||||
| var GeminiVersion = env.String("GEMINI_VERSION", "v1") | ||||
|   | ||||
							
								
								
									
										9
									
								
								common/config/key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								common/config/key.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| package config | ||||
|  | ||||
| const ( | ||||
| 	KeyPrefix = "cfg_" | ||||
|  | ||||
| 	KeyAPIVersion = KeyPrefix + "api_version" | ||||
| 	KeyLibraryID  = KeyPrefix + "library_id" | ||||
| 	KeyPlugin     = KeyPrefix + "plugin" | ||||
| ) | ||||
| @@ -4,101 +4,3 @@ import "time" | ||||
|  | ||||
| var StartTime = time.Now().Unix() // unit: second | ||||
| var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change | ||||
|  | ||||
| const ( | ||||
| 	RoleGuestUser  = 0 | ||||
| 	RoleCommonUser = 1 | ||||
| 	RoleAdminUser  = 10 | ||||
| 	RoleRootUser   = 100 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||
| 	UserStatusDisabled = 2 // also don't use 0 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value! | ||||
| 	TokenStatusDisabled  = 2 // also don't use 0 | ||||
| 	TokenStatusExpired   = 3 | ||||
| 	TokenStatusExhausted = 4 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||
| 	RedemptionCodeStatusDisabled = 2 // also don't use 0 | ||||
| 	RedemptionCodeStatusUsed     = 3 // also don't use 0 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelStatusUnknown          = 0 | ||||
| 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! | ||||
| 	ChannelStatusManuallyDisabled = 2 // also don't use 0 | ||||
| 	ChannelStatusAutoDisabled     = 3 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelTypeUnknown        = 0 | ||||
| 	ChannelTypeOpenAI         = 1 | ||||
| 	ChannelTypeAPI2D          = 2 | ||||
| 	ChannelTypeAzure          = 3 | ||||
| 	ChannelTypeCloseAI        = 4 | ||||
| 	ChannelTypeOpenAISB       = 5 | ||||
| 	ChannelTypeOpenAIMax      = 6 | ||||
| 	ChannelTypeOhMyGPT        = 7 | ||||
| 	ChannelTypeCustom         = 8 | ||||
| 	ChannelTypeAILS           = 9 | ||||
| 	ChannelTypeAIProxy        = 10 | ||||
| 	ChannelTypePaLM           = 11 | ||||
| 	ChannelTypeAPI2GPT        = 12 | ||||
| 	ChannelTypeAIGC2D         = 13 | ||||
| 	ChannelTypeAnthropic      = 14 | ||||
| 	ChannelTypeBaidu          = 15 | ||||
| 	ChannelTypeZhipu          = 16 | ||||
| 	ChannelTypeAli            = 17 | ||||
| 	ChannelTypeXunfei         = 18 | ||||
| 	ChannelType360            = 19 | ||||
| 	ChannelTypeOpenRouter     = 20 | ||||
| 	ChannelTypeAIProxyLibrary = 21 | ||||
| 	ChannelTypeFastGPT        = 22 | ||||
| 	ChannelTypeTencent        = 23 | ||||
| 	ChannelTypeGemini         = 24 | ||||
| 	ChannelTypeMoonshot       = 25 | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| 	"",                              // 0 | ||||
| 	"https://api.openai.com",        // 1 | ||||
| 	"https://oa.api2d.net",          // 2 | ||||
| 	"",                              // 3 | ||||
| 	"https://api.closeai-proxy.xyz", // 4 | ||||
| 	"https://api.openai-sb.com",     // 5 | ||||
| 	"https://api.openaimax.com",     // 6 | ||||
| 	"https://api.ohmygpt.com",       // 7 | ||||
| 	"",                              // 8 | ||||
| 	"https://api.caipacity.com",     // 9 | ||||
| 	"https://api.aiproxy.io",        // 10 | ||||
| 	"https://generativelanguage.googleapis.com", // 11 | ||||
| 	"https://api.api2gpt.com",                   // 12 | ||||
| 	"https://api.aigc2d.com",                    // 13 | ||||
| 	"https://api.anthropic.com",                 // 14 | ||||
| 	"https://aip.baidubce.com",                  // 15 | ||||
| 	"https://open.bigmodel.cn",                  // 16 | ||||
| 	"https://dashscope.aliyuncs.com",            // 17 | ||||
| 	"",                                          // 18 | ||||
| 	"https://ai.360.cn",                         // 19 | ||||
| 	"https://openrouter.ai/api",                 // 20 | ||||
| 	"https://api.aiproxy.io",                    // 21 | ||||
| 	"https://fastgpt.run/api/openapi",           // 22 | ||||
| 	"https://hunyuan.cloud.tencent.com",         // 23 | ||||
| 	"https://generativelanguage.googleapis.com", // 24 | ||||
| 	"https://api.moonshot.cn",                   // 25 | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	ConfigKeyPrefix = "cfg_" | ||||
|  | ||||
| 	ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version" | ||||
| 	ConfigKeyLibraryID  = ConfigKeyPrefix + "library_id" | ||||
| 	ConfigKeyPlugin     = ConfigKeyPrefix + "plugin" | ||||
| ) | ||||
|   | ||||
							
								
								
									
										6
									
								
								common/conv/any.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								common/conv/any.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package conv | ||||
|  | ||||
| func AsString(v any) string { | ||||
| 	str, _ := v.(string) | ||||
| 	return str | ||||
| } | ||||
| @@ -1,9 +1,12 @@ | ||||
| package common | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/common/helper" | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| ) | ||||
|  | ||||
| var UsingSQLite = false | ||||
| var UsingPostgreSQL = false | ||||
| var UsingMySQL = false | ||||
|  | ||||
| var SQLitePath = "one-api.db" | ||||
| var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) | ||||
| var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) | ||||
|   | ||||
							
								
								
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| package env | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| func Bool(env string, defaultValue bool) bool { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return os.Getenv(env) == "true" | ||||
| } | ||||
|  | ||||
| func Int(env string, defaultValue int) int { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.Atoi(os.Getenv(env)) | ||||
| 	if err != nil { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func Float64(env string, defaultValue float64) float64 { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.ParseFloat(os.Getenv(env), 64) | ||||
| 	if err != nil { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func String(env string, defaultValue string) string { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return os.Getenv(env) | ||||
| } | ||||
| @@ -8,12 +8,24 @@ import ( | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| const KeyRequestBody = "key_request_body" | ||||
|  | ||||
| func GetRequestBody(c *gin.Context) ([]byte, error) { | ||||
| 	requestBody, _ := c.Get(KeyRequestBody) | ||||
| 	if requestBody != nil { | ||||
| 		return requestBody.([]byte), nil | ||||
| 	} | ||||
| 	requestBody, err := io.ReadAll(c.Request.Body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	_ = c.Request.Body.Close() | ||||
| 	c.Set(KeyRequestBody, requestBody) | ||||
| 	return requestBody.([]byte), nil | ||||
| } | ||||
|  | ||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	requestBody, err := GetRequestBody(c) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -2,18 +2,14 @@ package helper | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"html/template" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func OpenBrowser(url string) { | ||||
| @@ -81,31 +77,6 @@ func Bytes2Size(num int64) string { | ||||
| 	return numStr + " " + unit | ||||
| } | ||||
|  | ||||
| func Seconds2Time(num int) (time string) { | ||||
| 	if num/31104000 > 0 { | ||||
| 		time += strconv.Itoa(num/31104000) + " 年 " | ||||
| 		num %= 31104000 | ||||
| 	} | ||||
| 	if num/2592000 > 0 { | ||||
| 		time += strconv.Itoa(num/2592000) + " 个月 " | ||||
| 		num %= 2592000 | ||||
| 	} | ||||
| 	if num/86400 > 0 { | ||||
| 		time += strconv.Itoa(num/86400) + " 天 " | ||||
| 		num %= 86400 | ||||
| 	} | ||||
| 	if num/3600 > 0 { | ||||
| 		time += strconv.Itoa(num/3600) + " 小时 " | ||||
| 		num %= 3600 | ||||
| 	} | ||||
| 	if num/60 > 0 { | ||||
| 		time += strconv.Itoa(num/60) + " 分钟 " | ||||
| 		num %= 60 | ||||
| 	} | ||||
| 	time += strconv.Itoa(num) + " 秒" | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func Interface2String(inter interface{}) string { | ||||
| 	switch inter := inter.(type) { | ||||
| 	case string: | ||||
| @@ -130,51 +101,8 @@ func IntMax(a int, b int) int { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUUID() string { | ||||
| 	code := uuid.New().String() | ||||
| 	code = strings.Replace(code, "-", "", -1) | ||||
| 	return code | ||||
| } | ||||
|  | ||||
| const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||
|  | ||||
| func init() { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| } | ||||
|  | ||||
| func GenerateKey() string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, 48) | ||||
| 	for i := 0; i < 16; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	uuid_ := GetUUID() | ||||
| 	for i := 0; i < 32; i++ { | ||||
| 		c := uuid_[i] | ||||
| 		if i%2 == 0 && c >= 'a' && c <= 'z' { | ||||
| 			c = c - 'a' + 'A' | ||||
| 		} | ||||
| 		key[i+16] = c | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetRandomString(length int) string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, length) | ||||
| 	for i := 0; i < length; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetTimestamp() int64 { | ||||
| 	return time.Now().Unix() | ||||
| } | ||||
|  | ||||
| func GetTimeString() string { | ||||
| 	now := time.Now() | ||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||
| func GenRequestID() string { | ||||
| 	return GetTimeString() + random.GetRandomNumberString(8) | ||||
| } | ||||
|  | ||||
| func Max(a int, b int) int { | ||||
| @@ -185,25 +113,6 @@ func Max(a int, b int) int { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetOrDefaultEnvInt(env string, defaultValue int) int { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.Atoi(os.Getenv(env)) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func GetOrDefaultEnvString(env string, defaultValue string) string { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return os.Getenv(env) | ||||
| } | ||||
|  | ||||
| func AssignOrDefault(value string, defaultValue string) string { | ||||
| 	if len(value) != 0 { | ||||
| 		return value | ||||
|   | ||||
							
								
								
									
										15
									
								
								common/helper/time.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								common/helper/time.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| package helper | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func GetTimestamp() int64 { | ||||
| 	return time.Now().Unix() | ||||
| } | ||||
|  | ||||
| func GetTimeString() string { | ||||
| 	now := time.Now() | ||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||
| } | ||||
| @@ -4,6 +4,8 @@ import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"os" | ||||
| @@ -13,14 +15,12 @@ import ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	loggerDEBUG = "DEBUG" | ||||
| 	loggerINFO  = "INFO" | ||||
| 	loggerWarn  = "WARN" | ||||
| 	loggerError = "ERR" | ||||
| ) | ||||
|  | ||||
| const maxLogCount = 1000000 | ||||
|  | ||||
| var logCount int | ||||
| var setupLogLock sync.Mutex | ||||
| var setupLogWorking bool | ||||
|  | ||||
| @@ -55,6 +55,12 @@ func SysError(s string) { | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||
| } | ||||
|  | ||||
| func Debug(ctx context.Context, msg string) { | ||||
| 	if config.DebugEnabled { | ||||
| 		logHelper(ctx, loggerDEBUG, msg) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Info(ctx context.Context, msg string) { | ||||
| 	logHelper(ctx, loggerINFO, msg) | ||||
| } | ||||
| @@ -67,6 +73,10 @@ func Error(ctx context.Context, msg string) { | ||||
| 	logHelper(ctx, loggerError, msg) | ||||
| } | ||||
|  | ||||
| func Debugf(ctx context.Context, format string, a ...any) { | ||||
| 	Debug(ctx, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Infof(ctx context.Context, format string, a ...any) { | ||||
| 	Info(ctx, fmt.Sprintf(format, a...)) | ||||
| } | ||||
| @@ -85,11 +95,12 @@ func logHelper(ctx context.Context, level string, msg string) { | ||||
| 		writer = gin.DefaultWriter | ||||
| 	} | ||||
| 	id := ctx.Value(RequestIdKey) | ||||
| 	if id == nil { | ||||
| 		id = helper.GenRequestID() | ||||
| 	} | ||||
| 	now := time.Now() | ||||
| 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||
| 	logCount++ // we don't need accurate count, so no lock here | ||||
| 	if logCount > maxLogCount && !setupLogWorking { | ||||
| 		logCount = 0 | ||||
| 	if !setupLogWorking { | ||||
| 		setupLogWorking = true | ||||
| 		go func() { | ||||
| 			SetupLogger() | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| package common | ||||
| package message | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| @@ -12,6 +12,9 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| 	if receiver == "" { | ||||
| 		return fmt.Errorf("receiver is empty") | ||||
| 	} | ||||
| 	if config.SMTPFrom == "" { // for compatibility | ||||
| 		config.SMTPFrom = config.SMTPAccount | ||||
| 	} | ||||
							
								
								
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package message | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ByAll           = "all" | ||||
| 	ByEmail         = "email" | ||||
| 	ByMessagePusher = "message_pusher" | ||||
| ) | ||||
|  | ||||
| func Notify(by string, title string, description string, content string) error { | ||||
| 	if by == ByEmail { | ||||
| 		return SendEmail(title, config.RootUserEmail, content) | ||||
| 	} | ||||
| 	if by == ByMessagePusher { | ||||
| 		return SendMessage(title, description, content) | ||||
| 	} | ||||
| 	return fmt.Errorf("unknown notify method: %s", by) | ||||
| } | ||||
							
								
								
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| package message | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type request struct { | ||||
| 	Title       string `json:"title"` | ||||
| 	Description string `json:"description"` | ||||
| 	Content     string `json:"content"` | ||||
| 	URL         string `json:"url"` | ||||
| 	Channel     string `json:"channel"` | ||||
| 	Token       string `json:"token"` | ||||
| } | ||||
|  | ||||
| type response struct { | ||||
| 	Success bool   `json:"success"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| func SendMessage(title string, description string, content string) error { | ||||
| 	if config.MessagePusherAddress == "" { | ||||
| 		return errors.New("message pusher address is not set") | ||||
| 	} | ||||
| 	req := request{ | ||||
| 		Title:       title, | ||||
| 		Description: description, | ||||
| 		Content:     content, | ||||
| 		Token:       config.MessagePusherToken, | ||||
| 	} | ||||
| 	data, err := json.Marshal(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	resp, err := http.Post(config.MessagePusherAddress, | ||||
| 		"application/json", bytes.NewBuffer(data)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	var res response | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&res) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !res.Success { | ||||
| 		return errors.New(res.Message) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										52
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | ||||
| package network | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"net" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func splitSubnets(subnets string) []string { | ||||
| 	res := strings.Split(subnets, ",") | ||||
| 	for i := 0; i < len(res); i++ { | ||||
| 		res[i] = strings.TrimSpace(res[i]) | ||||
| 	} | ||||
| 	return res | ||||
| } | ||||
|  | ||||
| func isValidSubnet(subnet string) error { | ||||
| 	_, _, err := net.ParseCIDR(subnet) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to parse subnet: %w", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { | ||||
| 	_, ipNet, err := net.ParseCIDR(subnet) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) | ||||
| 		return false | ||||
| 	} | ||||
| 	return ipNet.Contains(net.ParseIP(ip)) | ||||
| } | ||||
|  | ||||
| func IsValidSubnets(subnets string) error { | ||||
| 	for _, subnet := range splitSubnets(subnets) { | ||||
| 		if err := isValidSubnet(subnet); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { | ||||
| 	for _, subnet := range splitSubnets(subnets) { | ||||
| 		if isIpInSubnet(ctx, ip, subnet) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package network | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	. "github.com/smartystreets/goconvey/convey" | ||||
| ) | ||||
|  | ||||
| func TestIsIpInSubnet(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	ip1 := "192.168.0.5" | ||||
| 	ip2 := "125.216.250.89" | ||||
| 	subnet := "192.168.0.0/24" | ||||
| 	Convey("TestIsIpInSubnet", t, func() { | ||||
| 		So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) | ||||
| 		So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										61
									
								
								common/random/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								common/random/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| package random | ||||
|  | ||||
| import ( | ||||
| 	"github.com/google/uuid" | ||||
| 	"math/rand" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func GetUUID() string { | ||||
| 	code := uuid.New().String() | ||||
| 	code = strings.Replace(code, "-", "", -1) | ||||
| 	return code | ||||
| } | ||||
|  | ||||
| const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||
| const keyNumbers = "0123456789" | ||||
|  | ||||
| func init() { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| } | ||||
|  | ||||
| func GenerateKey() string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, 48) | ||||
| 	for i := 0; i < 16; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	uuid_ := GetUUID() | ||||
| 	for i := 0; i < 32; i++ { | ||||
| 		c := uuid_[i] | ||||
| 		if i%2 == 0 && c >= 'a' && c <= 'z' { | ||||
| 			c = c - 'a' + 'A' | ||||
| 		} | ||||
| 		key[i+16] = c | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetRandomString(length int) string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, length) | ||||
| 	for i := 0; i < length; i++ { | ||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| func GetRandomNumberString(length int) string { | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 	key := make([]byte, length) | ||||
| 	for i := 0; i < length; i++ { | ||||
| 		key[i] = keyNumbers[rand.Intn(len(keyNumbers))] | ||||
| 	} | ||||
| 	return string(key) | ||||
| } | ||||
|  | ||||
| // RandRange returns a random number between min and max (max is not included) | ||||
| func RandRange(min, max int) int { | ||||
| 	return min + rand.Intn(max-min) | ||||
| } | ||||
| @@ -5,7 +5,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| func LogQuota(quota int) string { | ||||
| func LogQuota(quota int64) string { | ||||
| 	if config.DisplayInCurrencyEnabled { | ||||
| 		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) | ||||
| 	} else { | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| @@ -7,10 +7,10 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) { | ||||
| 				user.DisplayName = "GitHub User" | ||||
| 			} | ||||
| 			user.Email = githubUser.Email | ||||
| 			user.Role = common.RoleCommonUser | ||||
| 			user.Status = common.UserStatusEnabled | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
| 
 | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -152,14 +152,14 @@ func GitHubOAuth(c *gin.Context) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if user.Status != common.UserStatusEnabled { | ||||
| 	if user.Status != model.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
| 
 | ||||
| func GitHubBind(c *gin.Context) { | ||||
| @@ -219,7 +219,7 @@ func GitHubBind(c *gin.Context) { | ||||
| 
 | ||||
| func GenerateOAuthCode(c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	state := helper.GetRandomString(12) | ||||
| 	state := random.GetRandomString(12) | ||||
| 	session.Set("oauth_state", state) | ||||
| 	err := session.Save() | ||||
| 	if err != nil { | ||||
							
								
								
									
										200
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										200
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,200 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type LarkOAuthResponse struct { | ||||
| 	AccessToken string `json:"access_token"` | ||||
| } | ||||
|  | ||||
| type LarkUser struct { | ||||
| 	Name   string `json:"name"` | ||||
| 	OpenID string `json:"open_id"` | ||||
| } | ||||
|  | ||||
| func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("无效的参数") | ||||
| 	} | ||||
| 	values := map[string]string{ | ||||
| 		"client_id":     config.LarkClientId, | ||||
| 		"client_secret": config.LarkClientSecret, | ||||
| 		"code":          code, | ||||
| 		"grant_type":    "authorization_code", | ||||
| 		"redirect_uri":  fmt.Sprintf("%s/oauth/lark", config.ServerAddress), | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(values) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	req.Header.Set("Accept", "application/json") | ||||
| 	client := http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至飞书服务器,请稍后重试!") | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	var oAuthResponse LarkOAuthResponse | ||||
| 	err = json.NewDecoder(res.Body).Decode(&oAuthResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | ||||
| 	res2, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至飞书服务器,请稍后重试!") | ||||
| 	} | ||||
| 	var larkUser LarkUser | ||||
| 	err = json.NewDecoder(res2.Body).Decode(&larkUser) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &larkUser, nil | ||||
| } | ||||
|  | ||||
| func LarkOAuth(c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| 		c.JSON(http.StatusForbidden, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "state is empty or not same", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	username := session.Get("username") | ||||
| 	if username != nil { | ||||
| 		LarkBind(c) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
| 	larkUser, err := getLarkUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		LarkId: larkUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsLarkIdAlreadyTaken(user.LarkId) { | ||||
| 		err := user.FillUserByLarkId() | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": err.Error(), | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if config.RegisterEnabled { | ||||
| 			user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			if larkUser.Name != "" { | ||||
| 				user.DisplayName = larkUser.Name | ||||
| 			} else { | ||||
| 				user.DisplayName = "Lark User" | ||||
| 			} | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员关闭了新用户注册", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if user.Status != model.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| func LarkBind(c *gin.Context) { | ||||
| 	code := c.Query("code") | ||||
| 	larkUser, err := getLarkUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		LarkId: larkUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsLarkIdAlreadyTaken(user.LarkId) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "该飞书账户已被绑定", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	session := sessions.Default(c) | ||||
| 	id := session.Get("id") | ||||
| 	// id := c.GetInt("id")  // critical bug! | ||||
| 	user.Id = id.(int) | ||||
| 	err = user.FillUserById() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.LarkId = larkUser.OpenID | ||||
| 	err = user.Update(false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "bind", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -1,12 +1,12 @@ | ||||
| package controller | ||||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -83,8 +83,8 @@ func WeChatAuth(c *gin.Context) { | ||||
| 		if config.RegisterEnabled { | ||||
| 			user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			user.DisplayName = "WeChat User" | ||||
| 			user.Role = common.RoleCommonUser | ||||
| 			user.Status = common.UserStatusEnabled | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
| 
 | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -102,14 +102,14 @@ func WeChatAuth(c *gin.Context) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if user.Status != common.UserStatusEnabled { | ||||
| 	if user.Status != model.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
| 
 | ||||
| func WeChatBind(c *gin.Context) { | ||||
| @@ -8,8 +8,8 @@ import ( | ||||
| ) | ||||
|  | ||||
| func GetSubscription(c *gin.Context) { | ||||
| 	var remainQuota int | ||||
| 	var usedQuota int | ||||
| 	var remainQuota int64 | ||||
| 	var usedQuota int64 | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	var expiredTime int64 | ||||
| @@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetUsage(c *gin.Context) { | ||||
| 	var quota int | ||||
| 	var quota int64 | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
|   | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/client" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -95,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | ||||
| 	for k := range headers { | ||||
| 		req.Header.Add(k, headers.Get(k)) | ||||
| 	} | ||||
| 	res, err := util.HTTPClient.Do(req) | ||||
| 	res, err := client.HTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -203,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
| } | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	baseURL := channeltype.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.GetBaseURL() == "" { | ||||
| 		channel.BaseURL = &baseURL | ||||
| 	} | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypeOpenAI: | ||||
| 	case channeltype.OpenAI: | ||||
| 		if channel.GetBaseURL() != "" { | ||||
| 			baseURL = channel.GetBaseURL() | ||||
| 		} | ||||
| 	case common.ChannelTypeAzure: | ||||
| 	case channeltype.Azure: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	case common.ChannelTypeCustom: | ||||
| 	case channeltype.Custom: | ||||
| 		baseURL = channel.GetBaseURL() | ||||
| 	case common.ChannelTypeCloseAI: | ||||
| 	case channeltype.CloseAI: | ||||
| 		return updateChannelCloseAIBalance(channel) | ||||
| 	case common.ChannelTypeOpenAISB: | ||||
| 	case channeltype.OpenAISB: | ||||
| 		return updateChannelOpenAISBBalance(channel) | ||||
| 	case common.ChannelTypeAIProxy: | ||||
| 	case channeltype.AIProxy: | ||||
| 		return updateChannelAIProxyBalance(channel) | ||||
| 	case common.ChannelTypeAPI2GPT: | ||||
| 	case channeltype.API2GPT: | ||||
| 		return updateChannelAPI2GPTBalance(channel) | ||||
| 	case common.ChannelTypeAIGC2D: | ||||
| 	case channeltype.AIGC2D: | ||||
| 		return updateChannelAIGC2DBalance(channel) | ||||
| 	default: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| @@ -295,16 +296,16 @@ func UpdateChannelBalance(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func updateAllChannelsBalance() error { | ||||
| 	channels, err := model.GetAllChannels(0, 0, true) | ||||
| 	channels, err := model.GetAllChannels(0, 0, "all") | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	for _, channel := range channels { | ||||
| 		if channel.Status != common.ChannelStatusEnabled { | ||||
| 		if channel.Status != model.ChannelStatusEnabled { | ||||
| 			continue | ||||
| 		} | ||||
| 		// TODO: support Azure | ||||
| 		if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { | ||||
| 		if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom { | ||||
| 			continue | ||||
| 		} | ||||
| 		balance, err := updateChannelBalance(channel) | ||||
| @@ -313,7 +314,7 @@ func updateAllChannelsBalance() error { | ||||
| 		} else { | ||||
| 			// err is nil & balance <= 0 means quota is used up | ||||
| 			if balance <= 0 { | ||||
| 				disableChannel(channel.Id, channel.Name, "余额不足") | ||||
| 				monitor.DisableChannel(channel.Id, channel.Name, "余额不足") | ||||
| 			} | ||||
| 		} | ||||
| 		time.Sleep(config.RequestInterval) | ||||
| @@ -322,15 +323,14 @@ func updateAllChannelsBalance() error { | ||||
| } | ||||
|  | ||||
| func UpdateAllChannelsBalance(c *gin.Context) { | ||||
| 	// TODO: make it async | ||||
| 	err := updateAllChannelsBalance() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	//err := updateAllChannelsBalance() | ||||
| 	//if err != nil { | ||||
| 	//	c.JSON(http.StatusOK, gin.H{ | ||||
| 	//		"success": false, | ||||
| 	//		"message": err.Error(), | ||||
| 	//	}) | ||||
| 	//	return | ||||
| 	//} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
|   | ||||
| @@ -5,19 +5,24 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/helper" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	relay "github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/controller" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| @@ -26,7 +31,7 @@ import ( | ||||
|  | ||||
| func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | ||||
| 	testRequest := &relaymodel.GeneralOpenAIRequest{ | ||||
| 		MaxTokens: 1, | ||||
| 		MaxTokens: 2, | ||||
| 		Stream:    false, | ||||
| 		Model:     "gpt-3.5-turbo", | ||||
| 	} | ||||
| @@ -51,18 +56,25 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 	c.Request.Header.Set("Content-Type", "application/json") | ||||
| 	c.Set("channel", channel.Type) | ||||
| 	c.Set("base_url", channel.GetBaseURL()) | ||||
| 	meta := util.GetRelayMeta(c) | ||||
| 	apiType := constant.ChannelType2APIType(channel.Type) | ||||
| 	adaptor := helper.GetAdaptor(apiType) | ||||
| 	middleware.SetupContextForSelectedChannel(c, channel, "") | ||||
| 	meta := meta.GetByContext(c) | ||||
| 	apiType := channeltype.ToAPIType(channel.Type) | ||||
| 	adaptor := relay.GetAdaptor(apiType) | ||||
| 	if adaptor == nil { | ||||
| 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
| 	modelName := adaptor.GetModelList()[0] | ||||
| 	if !strings.Contains(channel.Models, modelName) { | ||||
| 		modelNames := strings.Split(channel.Models, ",") | ||||
| 		if len(modelNames) > 0 { | ||||
| 			modelName = modelNames[0] | ||||
| 		} | ||||
| 	} | ||||
| 	request := buildTestRequest() | ||||
| 	request.Model = modelName | ||||
| 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | ||||
| 	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) | ||||
| 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| @@ -77,7 +89,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		err := util.RelayErrorHandler(resp) | ||||
| 		err := controller.RelayErrorHandler(resp) | ||||
| 		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error | ||||
| 	} | ||||
| 	usage, respErr := adaptor.DoResponse(c, resp, meta) | ||||
| @@ -139,33 +151,7 @@ func TestChannel(c *gin.Context) { | ||||
| var testAllChannelsLock sync.Mutex | ||||
| var testAllChannelsRunning bool = false | ||||
|  | ||||
| func notifyRootUser(subject string, content string) { | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	err := common.SendEmail(subject, config.RootUserEmail, content) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // disable & notify | ||||
| func disableChannel(channelId int, channelName string, reason string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| // enable & notify | ||||
| func enableChannel(channelId int, channelName string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| func testAllChannels(notify bool) error { | ||||
| func testChannels(notify bool, scope string) error { | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| @@ -176,7 +162,7 @@ func testAllChannels(notify bool) error { | ||||
| 	} | ||||
| 	testAllChannelsRunning = true | ||||
| 	testAllChannelsLock.Unlock() | ||||
| 	channels, err := model.GetAllChannels(0, 0, true) | ||||
| 	channels, err := model.GetAllChannels(0, 0, scope) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -186,20 +172,24 @@ func testAllChannels(notify bool) error { | ||||
| 	} | ||||
| 	go func() { | ||||
| 		for _, channel := range channels { | ||||
| 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled | ||||
| 			isChannelEnabled := channel.Status == model.ChannelStatusEnabled | ||||
| 			tik := time.Now() | ||||
| 			err, openaiErr := testChannel(channel) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 				if config.AutomaticDisableChannelEnabled { | ||||
| 					monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 				} else { | ||||
| 					_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) | ||||
| 				} | ||||
| 			} | ||||
| 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) { | ||||
| 				monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | ||||
| 				enableChannel(channel.Id, channel.Name) | ||||
| 			if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) { | ||||
| 				monitor.EnableChannel(channel.Id, channel.Name) | ||||
| 			} | ||||
| 			channel.UpdateResponseTime(milliseconds) | ||||
| 			time.Sleep(config.RequestInterval) | ||||
| @@ -208,7 +198,7 @@ func testAllChannels(notify bool) error { | ||||
| 		testAllChannelsRunning = false | ||||
| 		testAllChannelsLock.Unlock() | ||||
| 		if notify { | ||||
| 			err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | ||||
| 			err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常") | ||||
| 			if err != nil { | ||||
| 				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 			} | ||||
| @@ -217,8 +207,12 @@ func testAllChannels(notify bool) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func TestAllChannels(c *gin.Context) { | ||||
| 	err := testAllChannels(true) | ||||
| func TestChannels(c *gin.Context) { | ||||
| 	scope := c.Query("scope") | ||||
| 	if scope == "" { | ||||
| 		scope = "all" | ||||
| 	} | ||||
| 	err := testChannels(true, scope) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -237,7 +231,7 @@ func AutomaticallyTestChannels(frequency int) { | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||
| 		logger.SysLog("testing all channels") | ||||
| 		_ = testAllChannels(false) | ||||
| 		_ = testChannels(false, "all") | ||||
| 		logger.SysLog("channel test finished") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) | ||||
| 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
|   | ||||
| @@ -2,13 +2,13 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func GetGroups(c *gin.Context) { | ||||
| 	groupNames := make([]string, 0) | ||||
| 	for groupName := range common.GroupRatio { | ||||
| 	for groupName := range billingratio.GroupRatio { | ||||
| 		groupNames = append(groupNames, groupName) | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -22,6 +23,7 @@ func GetStatus(c *gin.Context) { | ||||
| 			"email_verification":  config.EmailVerificationEnabled, | ||||
| 			"github_oauth":        config.GitHubOAuthEnabled, | ||||
| 			"github_client_id":    config.GitHubClientId, | ||||
| 			"lark_client_id":      config.LarkClientId, | ||||
| 			"system_name":         config.SystemName, | ||||
| 			"logo":                config.Logo, | ||||
| 			"footer_html":         config.Footer, | ||||
| @@ -110,7 +112,7 @@ func SendEmailVerification(c *gin.Context) { | ||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | ||||
| 		"<p>您的验证码为: <strong>%s</strong></p>"+ | ||||
| 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	err := message.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -149,7 +151,7 @@ func SendPasswordResetEmail(c *gin.Context) { | ||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	err := message.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
|   | ||||
| @@ -3,11 +3,15 @@ package controller | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/helper" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	relay "github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/apitype" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/models/list | ||||
| @@ -37,8 +41,9 @@ type OpenAIModels struct { | ||||
| 	Parent     *string                 `json:"parent"` | ||||
| } | ||||
|  | ||||
| var openAIModels []OpenAIModels | ||||
| var openAIModelsMap map[string]OpenAIModels | ||||
| var models []OpenAIModels | ||||
| var modelsMap map[string]OpenAIModels | ||||
| var channelId2Models map[int][]string | ||||
|  | ||||
| func init() { | ||||
| 	var permission []OpenAIModelPermission | ||||
| @@ -57,15 +62,15 @@ func init() { | ||||
| 		IsBlocking:         false, | ||||
| 	}) | ||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| 	for i := 0; i < constant.APITypeDummy; i++ { | ||||
| 		if i == constant.APITypeAIProxyLibrary { | ||||
| 	for i := 0; i < apitype.Dummy; i++ { | ||||
| 		if i == apitype.AIProxyLibrary { | ||||
| 			continue | ||||
| 		} | ||||
| 		adaptor := helper.GetAdaptor(i) | ||||
| 		adaptor := relay.GetAdaptor(i) | ||||
| 		channelName := adaptor.GetChannelName() | ||||
| 		modelNames := adaptor.GetModelList() | ||||
| 		for _, modelName := range modelNames { | ||||
| 			openAIModels = append(openAIModels, OpenAIModels{ | ||||
| 			models = append(models, OpenAIModels{ | ||||
| 				Id:         modelName, | ||||
| 				Object:     "model", | ||||
| 				Created:    1626777600, | ||||
| @@ -76,44 +81,95 @@ func init() { | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	for _, modelName := range ai360.ModelList { | ||||
| 		openAIModels = append(openAIModels, OpenAIModels{ | ||||
| 			Id:         modelName, | ||||
| 			Object:     "model", | ||||
| 			Created:    1626777600, | ||||
| 			OwnedBy:    "360", | ||||
| 			Permission: permission, | ||||
| 			Root:       modelName, | ||||
| 			Parent:     nil, | ||||
| 		}) | ||||
| 	for _, channelType := range openai.CompatibleChannels { | ||||
| 		if channelType == channeltype.Azure { | ||||
| 			continue | ||||
| 		} | ||||
| 		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) | ||||
| 		for _, modelName := range channelModelList { | ||||
| 			models = append(models, OpenAIModels{ | ||||
| 				Id:         modelName, | ||||
| 				Object:     "model", | ||||
| 				Created:    1626777600, | ||||
| 				OwnedBy:    channelName, | ||||
| 				Permission: permission, | ||||
| 				Root:       modelName, | ||||
| 				Parent:     nil, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	for _, modelName := range moonshot.ModelList { | ||||
| 		openAIModels = append(openAIModels, OpenAIModels{ | ||||
| 			Id:         modelName, | ||||
| 			Object:     "model", | ||||
| 			Created:    1626777600, | ||||
| 			OwnedBy:    "moonshot", | ||||
| 			Permission: permission, | ||||
| 			Root:       modelName, | ||||
| 			Parent:     nil, | ||||
| 		}) | ||||
| 	modelsMap = make(map[string]OpenAIModels) | ||||
| 	for _, model := range models { | ||||
| 		modelsMap[model.Id] = model | ||||
| 	} | ||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | ||||
| 	for _, model := range openAIModels { | ||||
| 		openAIModelsMap[model.Id] = model | ||||
| 	channelId2Models = make(map[int][]string) | ||||
| 	for i := 1; i < channeltype.Dummy; i++ { | ||||
| 		adaptor := relay.GetAdaptor(channeltype.ToAPIType(i)) | ||||
| 		meta := &meta.Meta{ | ||||
| 			ChannelType: i, | ||||
| 		} | ||||
| 		adaptor.Init(meta) | ||||
| 		channelId2Models[i] = adaptor.GetModelList() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ListModels(c *gin.Context) { | ||||
| func DashboardListModels(c *gin.Context) { | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    channelId2Models, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func ListAllModels(c *gin.Context) { | ||||
| 	c.JSON(200, gin.H{ | ||||
| 		"object": "list", | ||||
| 		"data":   openAIModels, | ||||
| 		"data":   models, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func ListModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var availableModels []string | ||||
| 	if c.GetString("available_models") != "" { | ||||
| 		availableModels = strings.Split(c.GetString("available_models"), ",") | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) | ||||
| 	} | ||||
| 	modelSet := make(map[string]bool) | ||||
| 	for _, availableModel := range availableModels { | ||||
| 		modelSet[availableModel] = true | ||||
| 	} | ||||
| 	availableOpenAIModels := make([]OpenAIModels, 0) | ||||
| 	for _, model := range models { | ||||
| 		if _, ok := modelSet[model.Id]; ok { | ||||
| 			modelSet[model.Id] = false | ||||
| 			availableOpenAIModels = append(availableOpenAIModels, model) | ||||
| 		} | ||||
| 	} | ||||
| 	for modelName, ok := range modelSet { | ||||
| 		if ok { | ||||
| 			availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ | ||||
| 				Id:      modelName, | ||||
| 				Object:  "model", | ||||
| 				Created: 1626777600, | ||||
| 				OwnedBy: "custom", | ||||
| 				Root:    modelName, | ||||
| 				Parent:  nil, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	c.JSON(200, gin.H{ | ||||
| 		"object": "list", | ||||
| 		"data":   availableOpenAIModels, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func RetrieveModel(c *gin.Context) { | ||||
| 	modelId := c.Param("model") | ||||
| 	if model, ok := openAIModelsMap[modelId]; ok { | ||||
| 	if model, ok := modelsMap[modelId]; ok { | ||||
| 		c.JSON(200, model) | ||||
| 	} else { | ||||
| 		Error := relaymodel.Error{ | ||||
| @@ -127,3 +183,30 @@ func RetrieveModel(c *gin.Context) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUserAvailableModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	id := c.GetInt("id") | ||||
| 	userGroup, err := model.CacheGetUserGroup(id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	models, err := model.CacheGetGroupModels(ctx, userGroup) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    models, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -106,7 +107,7 @@ func AddRedemption(c *gin.Context) { | ||||
| 	} | ||||
| 	var keys []string | ||||
| 	for i := 0; i < redemption.Count; i++ { | ||||
| 		key := helper.GetUUID() | ||||
| 		key := random.GetUUID() | ||||
| 		cleanRedemption := model.Redemption{ | ||||
| 			UserId:      c.GetInt("id"), | ||||
| 			Name:        redemption.Name, | ||||
|   | ||||
| @@ -1,62 +1,126 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	dbmodel "github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/controller" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/chat | ||||
|  | ||||
| func Relay(c *gin.Context) { | ||||
| 	relayMode := constant.Path2RelayMode(c.Request.URL.Path) | ||||
| func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { | ||||
| 	var err *model.ErrorWithStatusCode | ||||
| 	switch relayMode { | ||||
| 	case constant.RelayModeImagesGenerations: | ||||
| 	case relaymode.ImagesGenerations: | ||||
| 		err = controller.RelayImageHelper(c, relayMode) | ||||
| 	case constant.RelayModeAudioSpeech: | ||||
| 	case relaymode.AudioSpeech: | ||||
| 		fallthrough | ||||
| 	case constant.RelayModeAudioTranslation: | ||||
| 	case relaymode.AudioTranslation: | ||||
| 		fallthrough | ||||
| 	case constant.RelayModeAudioTranscription: | ||||
| 	case relaymode.AudioTranscription: | ||||
| 		err = controller.RelayAudioHelper(c, relayMode) | ||||
| 	default: | ||||
| 		err = controller.RelayTextHelper(c) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		requestId := c.GetString(logger.RequestIdKey) | ||||
| 		retryTimesStr := c.Query("retry") | ||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||
| 		if retryTimesStr == "" { | ||||
| 			retryTimes = config.RetryTimes | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func Relay(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	relayMode := relaymode.GetByPath(c.Request.URL.Path) | ||||
| 	if config.DebugEnabled { | ||||
| 		requestBody, _ := common.GetRequestBody(c) | ||||
| 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||
| 	} | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	bizErr := relayHelper(c, relayMode) | ||||
| 	if bizErr == nil { | ||||
| 		monitor.Emit(channelId, true) | ||||
| 		return | ||||
| 	} | ||||
| 	lastFailedChannelId := channelId | ||||
| 	channelName := c.GetString("channel_name") | ||||
| 	group := c.GetString("group") | ||||
| 	originalModel := c.GetString("original_model") | ||||
| 	go processChannelRelayError(ctx, channelId, channelName, bizErr) | ||||
| 	requestId := c.GetString(logger.RequestIdKey) | ||||
| 	retryTimes := config.RetryTimes | ||||
| 	if !shouldRetry(c, bizErr.StatusCode) { | ||||
| 		logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) | ||||
| 		retryTimes = 0 | ||||
| 	} | ||||
| 	for i := retryTimes; i > 0; i-- { | ||||
| 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) | ||||
| 			break | ||||
| 		} | ||||
| 		if retryTimes > 0 { | ||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||
| 		} else { | ||||
| 			if err.StatusCode == http.StatusTooManyRequests { | ||||
| 				err.Error.Message = "当前分组上游负载已饱和,请稍后再试" | ||||
| 			} | ||||
| 			err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId) | ||||
| 			c.JSON(err.StatusCode, gin.H{ | ||||
| 				"error": err.Error, | ||||
| 			}) | ||||
| 		logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) | ||||
| 		if channel.Id == lastFailedChannelId { | ||||
| 			continue | ||||
| 		} | ||||
| 		middleware.SetupContextForSelectedChannel(c, channel, originalModel) | ||||
| 		requestBody, err := common.GetRequestBody(c) | ||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 		bizErr = relayHelper(c, relayMode) | ||||
| 		if bizErr == nil { | ||||
| 			return | ||||
| 		} | ||||
| 		channelId := c.GetInt("channel_id") | ||||
| 		logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 		if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			channelName := c.GetString("channel_name") | ||||
| 			disableChannel(channelId, channelName, err.Message) | ||||
| 		lastFailedChannelId = channelId | ||||
| 		channelName := c.GetString("channel_name") | ||||
| 		go processChannelRelayError(ctx, channelId, channelName, bizErr) | ||||
| 	} | ||||
| 	if bizErr != nil { | ||||
| 		if bizErr.StatusCode == http.StatusTooManyRequests { | ||||
| 			bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" | ||||
| 		} | ||||
| 		bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) | ||||
| 		c.JSON(bizErr.StatusCode, gin.H{ | ||||
| 			"error": bizErr.Error, | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func shouldRetry(c *gin.Context, statusCode int) bool { | ||||
| 	if _, ok := c.Get("specific_channel_id"); ok { | ||||
| 		return false | ||||
| 	} | ||||
| 	if statusCode == http.StatusTooManyRequests { | ||||
| 		return true | ||||
| 	} | ||||
| 	if statusCode/100 == 5 { | ||||
| 		return true | ||||
| 	} | ||||
| 	if statusCode == http.StatusBadRequest { | ||||
| 		return false | ||||
| 	} | ||||
| 	if statusCode/100 == 2 { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { | ||||
| 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) | ||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||
| 		monitor.DisableChannel(channelId, channelName, err.Message) | ||||
| 	} else { | ||||
| 		monitor.Emit(channelId, false) | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,10 +1,12 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/network" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -16,7 +18,10 @@ func GetAllTokens(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) | ||||
|  | ||||
| 	order := c.Query("order") | ||||
| 	tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -101,6 +106,19 @@ func GetTokenStatus(c *gin.Context) { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func validateToken(c *gin.Context, token model.Token) error { | ||||
| 	if len(token.Name) > 30 { | ||||
| 		return fmt.Errorf("令牌名称过长") | ||||
| 	} | ||||
| 	if token.Subnet != nil && *token.Subnet != "" { | ||||
| 		err := network.IsValidSubnets(*token.Subnet) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("无效的网段:%s", err.Error()) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func AddToken(c *gin.Context) { | ||||
| 	token := model.Token{} | ||||
| 	err := c.ShouldBindJSON(&token) | ||||
| @@ -111,22 +129,26 @@ func AddToken(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(token.Name) > 30 { | ||||
| 	err = validateToken(c, token) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "令牌名称过长", | ||||
| 			"message": fmt.Sprintf("参数错误:%s", err.Error()), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	cleanToken := model.Token{ | ||||
| 		UserId:         c.GetInt("id"), | ||||
| 		Name:           token.Name, | ||||
| 		Key:            helper.GenerateKey(), | ||||
| 		Key:            random.GenerateKey(), | ||||
| 		CreatedTime:    helper.GetTimestamp(), | ||||
| 		AccessedTime:   helper.GetTimestamp(), | ||||
| 		ExpiredTime:    token.ExpiredTime, | ||||
| 		RemainQuota:    token.RemainQuota, | ||||
| 		UnlimitedQuota: token.UnlimitedQuota, | ||||
| 		Models:         token.Models, | ||||
| 		Subnet:         token.Subnet, | ||||
| 	} | ||||
| 	err = cleanToken.Insert() | ||||
| 	if err != nil { | ||||
| @@ -139,6 +161,7 @@ func AddToken(c *gin.Context) { | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    cleanToken, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -173,10 +196,11 @@ func UpdateToken(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(token.Name) > 30 { | ||||
| 	err = validateToken(c, token) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "令牌名称过长", | ||||
| 			"message": fmt.Sprintf("参数错误:%s", err.Error()), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| @@ -188,15 +212,15 @@ func UpdateToken(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if token.Status == common.TokenStatusEnabled { | ||||
| 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { | ||||
| 	if token.Status == model.TokenStatusEnabled { | ||||
| 		if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 		if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { | ||||
| 		if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", | ||||
| @@ -212,6 +236,8 @@ func UpdateToken(c *gin.Context) { | ||||
| 		cleanToken.ExpiredTime = token.ExpiredTime | ||||
| 		cleanToken.RemainQuota = token.RemainQuota | ||||
| 		cleanToken.UnlimitedQuota = token.UnlimitedQuota | ||||
| 		cleanToken.Models = token.Models | ||||
| 		cleanToken.Subnet = token.Subnet | ||||
| 	} | ||||
| 	err = cleanToken.Update() | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -58,11 +58,11 @@ func Login(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| 	SetupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| // setup session & cookies and then return user info | ||||
| func setupLogin(user *model.User, c *gin.Context) { | ||||
| func SetupLogin(user *model.User, c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	session.Set("id", user.Id) | ||||
| 	session.Set("username", user.Username) | ||||
| @@ -184,7 +184,10 @@ func GetAllUsers(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) | ||||
|  | ||||
| 	order := c.DefaultQuery("order", "") | ||||
| 	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -192,12 +195,12 @@ func GetAllUsers(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    users, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func SearchUsers(c *gin.Context) { | ||||
| @@ -236,7 +239,7 @@ func GetUser(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	myRole := c.GetInt("role") | ||||
| 	if myRole <= user.Role && myRole != common.RoleRootUser { | ||||
| 	if myRole <= user.Role && myRole != model.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "无权获取同级或更高等级用户的信息", | ||||
| @@ -284,7 +287,7 @@ func GenerateAccessToken(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.AccessToken = helper.GetUUID() | ||||
| 	user.AccessToken = random.GetUUID() | ||||
|  | ||||
| 	if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -321,7 +324,7 @@ func GetAffCode(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	if user.AffCode == "" { | ||||
| 		user.AffCode = helper.GetRandomString(4) | ||||
| 		user.AffCode = random.GetRandomString(4) | ||||
| 		if err := user.Update(false); err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| @@ -385,14 +388,14 @@ func UpdateUser(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	myRole := c.GetInt("role") | ||||
| 	if myRole <= originUser.Role && myRole != common.RoleRootUser { | ||||
| 	if myRole <= originUser.Role && myRole != model.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "无权更新同权限等级或更高权限等级的用户信息", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if myRole <= updatedUser.Role && myRole != common.RoleRootUser { | ||||
| 	if myRole <= updatedUser.Role && myRole != model.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "无权将其他用户权限等级提升到大于等于自己的权限等级", | ||||
| @@ -506,7 +509,7 @@ func DeleteSelf(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	user, _ := model.GetUserById(id, false) | ||||
|  | ||||
| 	if user.Role == common.RoleRootUser { | ||||
| 	if user.Role == model.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "不能删除超级管理员账户", | ||||
| @@ -608,7 +611,7 @@ func ManageUser(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	myRole := c.GetInt("role") | ||||
| 	if myRole <= user.Role && myRole != common.RoleRootUser { | ||||
| 	if myRole <= user.Role && myRole != model.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "无权更新同权限等级或更高权限等级的用户信息", | ||||
| @@ -617,8 +620,8 @@ func ManageUser(c *gin.Context) { | ||||
| 	} | ||||
| 	switch req.Action { | ||||
| 	case "disable": | ||||
| 		user.Status = common.UserStatusDisabled | ||||
| 		if user.Role == common.RoleRootUser { | ||||
| 		user.Status = model.UserStatusDisabled | ||||
| 		if user.Role == model.RoleRootUser { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法禁用超级管理员用户", | ||||
| @@ -626,9 +629,9 @@ func ManageUser(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
| 	case "enable": | ||||
| 		user.Status = common.UserStatusEnabled | ||||
| 		user.Status = model.UserStatusEnabled | ||||
| 	case "delete": | ||||
| 		if user.Role == common.RoleRootUser { | ||||
| 		if user.Role == model.RoleRootUser { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法删除超级管理员用户", | ||||
| @@ -643,37 +646,37 @@ func ManageUser(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
| 	case "promote": | ||||
| 		if myRole != common.RoleRootUser { | ||||
| 		if myRole != model.RoleRootUser { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "普通管理员用户无法提升其他用户为管理员", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 		if user.Role >= common.RoleAdminUser { | ||||
| 		if user.Role >= model.RoleAdminUser { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "该用户已经是管理员", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 		user.Role = common.RoleAdminUser | ||||
| 		user.Role = model.RoleAdminUser | ||||
| 	case "demote": | ||||
| 		if user.Role == common.RoleRootUser { | ||||
| 		if user.Role == model.RoleRootUser { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法降级超级管理员用户", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 		if user.Role == common.RoleCommonUser { | ||||
| 		if user.Role == model.RoleCommonUser { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "该用户已经是普通用户", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 		user.Role = common.RoleCommonUser | ||||
| 		user.Role = model.RoleCommonUser | ||||
| 	} | ||||
|  | ||||
| 	if err := user.Update(false); err != nil { | ||||
| @@ -727,7 +730,7 @@ func EmailBind(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if user.Role == common.RoleRootUser { | ||||
| 	if user.Role == model.RoleRootUser { | ||||
| 		config.RootUserEmail = email | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -767,3 +770,38 @@ func TopUp(c *gin.Context) { | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| type adminTopUpRequest struct { | ||||
| 	UserId int    `json:"user_id"` | ||||
| 	Quota  int    `json:"quota"` | ||||
| 	Remark string `json:"remark"` | ||||
| } | ||||
|  | ||||
| func AdminTopUp(c *gin.Context) { | ||||
| 	req := adminTopUpRequest{} | ||||
| 	err := c.ShouldBindJSON(&req) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if req.Remark == "" { | ||||
| 		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) | ||||
| 	} | ||||
| 	model.RecordTopupLog(req.UserId, req.Remark, req.Quota) | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -2,7 +2,7 @@ version: '3.4' | ||||
|  | ||||
| services: | ||||
|   one-api: | ||||
|     image: justsong/one-api:latest | ||||
|     image: "${REGISTRY:-docker.io}/justsong/one-api:latest" | ||||
|     container_name: one-api | ||||
|     restart: always | ||||
|     command: --log-dir /app/logs | ||||
| @@ -29,12 +29,12 @@ services: | ||||
|       retries: 3 | ||||
|  | ||||
|   redis: | ||||
|     image: redis:latest | ||||
|     image: "${REGISTRY:-docker.io}/redis:latest" | ||||
|     container_name: redis | ||||
|     restart: always | ||||
|  | ||||
|   db: | ||||
|     image: mysql:8.2.0 | ||||
|     image: "${REGISTRY:-docker.io}/mysql:8.2.0" | ||||
|     restart: always | ||||
|     container_name: mysql | ||||
|     volumes: | ||||
|   | ||||
							
								
								
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| # 使用 API 操控 & 扩展 One API | ||||
| > 欢迎提交 PR 在此放上你的拓展项目。 | ||||
|  | ||||
| 例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 | ||||
|  | ||||
| 又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 | ||||
|  | ||||
| ## 鉴权 | ||||
| One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: | ||||
|  | ||||
|  | ||||
|  | ||||
| 之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: | ||||
|  | ||||
|  | ||||
| ## 请求格式与响应格式 | ||||
| One API 使用 JSON 格式进行请求和响应。 | ||||
|  | ||||
| 对于响应体,一般格式如下: | ||||
| ```json | ||||
| { | ||||
|   "message": "请求信息", | ||||
|   "success": true, | ||||
|   "data": {} | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## API 列表 | ||||
| > 当前 API 列表不全,请自行通过浏览器抓取前端请求 | ||||
|  | ||||
| 如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 | ||||
|  | ||||
| ### 获取当前登录用户信息 | ||||
| **GET** `/api/user/self` | ||||
|  | ||||
| ### 为给定用户充值额度 | ||||
| **POST** `/api/topup` | ||||
| ```json | ||||
| { | ||||
|   "user_id": 1, | ||||
|   "quota": 100000, | ||||
|   "remark": "充值 100000 额度" | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 其他 | ||||
| ### 充值链接上的附加参数 | ||||
| One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: | ||||
| `https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` | ||||
|  | ||||
| 你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 | ||||
|  | ||||
| 注意,不是所有主题都支持该功能,欢迎 PR 补齐。 | ||||
							
								
								
									
										10
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.mod
									
									
									
									
									
								
							| @@ -15,6 +15,7 @@ require ( | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||
| 	github.com/smartystreets/goconvey v1.8.1 | ||||
| 	github.com/stretchr/testify v1.8.3 | ||||
| 	golang.org/x/crypto v0.17.0 | ||||
| 	golang.org/x/image v0.14.0 | ||||
| @@ -37,15 +38,18 @@ require ( | ||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.6.0 // indirect | ||||
| 	github.com/goccy/go-json v0.10.2 // indirect | ||||
| 	github.com/gopherjs/gopherjs v1.17.2 // indirect | ||||
| 	github.com/gorilla/context v1.1.1 // indirect | ||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | ||||
| 	github.com/gorilla/sessions v1.2.1 // indirect | ||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||
| 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | ||||
| 	github.com/jackc/pgx/v5 v5.3.1 // indirect | ||||
| 	github.com/jackc/pgx/v5 v5.5.4 // indirect | ||||
| 	github.com/jackc/puddle/v2 v2.2.1 // indirect | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/jinzhu/now v1.1.5 // indirect | ||||
| 	github.com/json-iterator/go v1.1.12 // indirect | ||||
| 	github.com/jtolds/gls v4.20.0+incompatible // indirect | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | ||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.19 // indirect | ||||
| @@ -54,12 +58,14 @@ require ( | ||||
| 	github.com/modern-go/reflect2 v1.0.2 // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/smarty/assertions v1.15.0 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| 	golang.org/x/net v0.17.0 // indirect | ||||
| 	golang.org/x/sync v0.1.0 // indirect | ||||
| 	golang.org/x/sys v0.15.0 // indirect | ||||
| 	golang.org/x/text v0.14.0 // indirect | ||||
| 	google.golang.org/protobuf v1.30.0 // indirect | ||||
| 	google.golang.org/protobuf v1.33.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										24
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								go.sum
									
									
									
									
									
								
							| @@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||
| github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= | ||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||
| github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= | ||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= | ||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||
| github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= | ||||
| github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||||
| github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= | ||||
| github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= | ||||
| github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= | ||||
| github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | ||||
| github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | ||||
| @@ -73,8 +75,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI | ||||
| github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= | ||||
| github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= | ||||
| github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= | ||||
| github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= | ||||
| github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= | ||||
| github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= | ||||
| github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= | ||||
| github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= | ||||
| github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| @@ -83,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/ | ||||
| github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= | ||||
| github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | ||||
| github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | ||||
| github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= | ||||
| github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= | ||||
| github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= | ||||
| @@ -125,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN | ||||
| github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= | ||||
| github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= | ||||
| github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= | ||||
| github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= | ||||
| github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= | ||||
| github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= | ||||
| github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= | ||||
| github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||||
| github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= | ||||
| github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= | ||||
| @@ -157,6 +167,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | ||||
| golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | ||||
| golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= | ||||
| golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| @@ -173,12 +185,12 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= | ||||
| golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||||
| google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | ||||
| google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= | ||||
| google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= | ||||
| google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||
|   | ||||
							
								
								
									
										35
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								i18n/en.json
									
									
									
									
									
								
							| @@ -8,12 +8,12 @@ | ||||
|   "确认删除": "Confirm Delete", | ||||
|   "确认绑定": "Confirm Binding", | ||||
|   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", | ||||
|   "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", | ||||
|   "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", | ||||
|   "\"渠道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", | ||||
|   "渠道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", | ||||
|   "测试已在运行中": "Test is already running", | ||||
|   "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", | ||||
|   "通道测试完成": "Channel test completed", | ||||
|   "通道测试完成,如果没有收到禁用通知,说明所有通道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", | ||||
|   "渠道测试完成": "Channel test completed", | ||||
|   "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", | ||||
|   "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", | ||||
|   "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", | ||||
|   "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", | ||||
| @@ -119,11 +119,11 @@ | ||||
|   " 个月 ": " M ", | ||||
|   " 年 ": " y ", | ||||
|   "未测试": "Not tested", | ||||
|   "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", | ||||
|   "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", | ||||
|   "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", | ||||
|   "渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", | ||||
|   "已成功开始测试所有渠道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "已成功开始测试所有已启用渠道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "渠道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", | ||||
|   "已更新完毕所有已启用渠道余额!": "The balance of all enabled channels has been updated!", | ||||
|   "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", | ||||
|   "名称": "Name", | ||||
|   "分组": "Group", | ||||
| @@ -141,9 +141,9 @@ | ||||
|   "启用": "Enable", | ||||
|   "编辑": "Edit", | ||||
|   "添加新的渠道": "Add a new channel", | ||||
|   "测试所有通道": "Test all channels", | ||||
|   "测试所有已启用通道": "Test all enabled channels", | ||||
|   "更新所有已启用通道余额": "Update the balance of all enabled channels", | ||||
|   "测试所有渠道": "Test all channels", | ||||
|   "测试所有已启用渠道": "Test all enabled channels", | ||||
|   "更新所有已启用渠道余额": "Update the balance of all enabled channels", | ||||
|   "刷新": "Refresh", | ||||
|   "处理中...": "Processing...", | ||||
|   "绑定成功!": "Binding succeeded!", | ||||
| @@ -207,11 +207,11 @@ | ||||
|   "监控设置": "Monitoring Settings", | ||||
|   "最长响应时间": "Longest Response Time", | ||||
|   "单位秒": "Unit in seconds", | ||||
|   "当运行通道全部测试时": "When all operating channels are tested", | ||||
|   "超过此时间将自动禁用通道": "Channels will be automatically disabled if this time is exceeded", | ||||
|   "当运行渠道全部测试时": "When all operating channels are tested", | ||||
|   "超过此时间将自动禁用渠道": "Channels will be automatically disabled if this time is exceeded", | ||||
|   "额度提醒阈值": "Quota reminder threshold", | ||||
|   "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", | ||||
|   "失败时自动禁用通道": "Automatically disable the channel when it fails", | ||||
|   "失败时自动禁用渠道": "Automatically disable the channel when it fails", | ||||
|   "保存监控设置": "Save Monitoring Settings", | ||||
|   "额度设置": "Quota Settings", | ||||
|   "新用户初始额度": "Initial quota for new users", | ||||
| @@ -405,7 +405,7 @@ | ||||
|   "镜像": "Mirror", | ||||
|   "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", | ||||
|   "模型": "Model", | ||||
|   "请选择该通道所支持的模型": "Please select the model supported by the channel", | ||||
|   "请选择该渠道所支持的模型": "Please select the model supported by the channel", | ||||
|   "填入基础模型": "Fill in the basic model", | ||||
|   "填入所有模型": "Fill in all models", | ||||
|   "清除所有模型": "Clear all models", | ||||
| @@ -456,6 +456,7 @@ | ||||
|   "已绑定的邮箱账户": "Email Account Bound", | ||||
|   "用户信息更新成功!": "User information updated successfully!", | ||||
|   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", | ||||
|   "模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f", | ||||
|   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", | ||||
|   "用户名称": "User Name", | ||||
|   "令牌名称": "Token Name", | ||||
| @@ -514,7 +515,7 @@ | ||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||
|   "Homepage URL 填": "Fill in the Homepage URL", | ||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||
|   "请为通道命名": "Please name the channel", | ||||
|   "请为渠道命名": "Please name the channel", | ||||
|   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||
|   "模型重定向": "Model redirection", | ||||
|   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||
|   | ||||
							
								
								
									
										28
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								main.go
									
									
									
									
									
								
							| @@ -12,7 +12,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/router" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| @@ -30,11 +30,25 @@ func main() { | ||||
| 	if config.DebugEnabled { | ||||
| 		logger.SysLog("running in debug mode") | ||||
| 	} | ||||
| 	var err error | ||||
| 	// Initialize SQL Database | ||||
| 	err := model.InitDB() | ||||
| 	model.DB, err = model.InitDB("SQL_DSN") | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||
| 	} | ||||
| 	if os.Getenv("LOG_SQL_DSN") != "" { | ||||
| 		logger.SysLog("using secondary database for table logs") | ||||
| 		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to initialize secondary database: " + err.Error()) | ||||
| 		} | ||||
| 	} else { | ||||
| 		model.LOG_DB = model.DB | ||||
| 	} | ||||
| 	err = model.CreateRootAccountIfNeed() | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("database init error: " + err.Error()) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		err := model.CloseDB() | ||||
| 		if err != nil { | ||||
| @@ -64,13 +78,6 @@ func main() { | ||||
| 		go model.SyncOptions(config.SyncFrequency) | ||||
| 		go model.SyncChannelCache(config.SyncFrequency) | ||||
| 	} | ||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) | ||||
| 		} | ||||
| 		go controller.AutomaticallyUpdateChannels(frequency) | ||||
| 	} | ||||
| 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | ||||
| 		if err != nil { | ||||
| @@ -83,6 +90,9 @@ func main() { | ||||
| 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | ||||
| 		model.InitBatchUpdater() | ||||
| 	} | ||||
| 	if config.EnableMetric { | ||||
| 		logger.SysLog("metric enabled, will disable channel if too much request failed") | ||||
| 	} | ||||
| 	openai.InitTokenEncoders() | ||||
|  | ||||
| 	// Initialize HTTP server | ||||
|   | ||||
| @@ -1,9 +1,11 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/blacklist" | ||||
| 	"github.com/songquanpeng/one-api/common/network" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -42,11 +44,14 @@ func authHelper(c *gin.Context, minRole int) { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	if status.(int) == common.UserStatusDisabled { | ||||
| 	if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "用户已被封禁", | ||||
| 		}) | ||||
| 		session := sessions.Default(c) | ||||
| 		session.Clear() | ||||
| 		_ = session.Save() | ||||
| 		c.Abort() | ||||
| 		return | ||||
| 	} | ||||
| @@ -66,24 +71,25 @@ func authHelper(c *gin.Context, minRole int) { | ||||
|  | ||||
| func UserAuth() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		authHelper(c, common.RoleCommonUser) | ||||
| 		authHelper(c, model.RoleCommonUser) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func AdminAuth() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		authHelper(c, common.RoleAdminUser) | ||||
| 		authHelper(c, model.RoleAdminUser) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RootAuth() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		authHelper(c, common.RoleRootUser) | ||||
| 		authHelper(c, model.RoleRootUser) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TokenAuth() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		ctx := c.Request.Context() | ||||
| 		key := c.Request.Header.Get("Authorization") | ||||
| 		key = strings.TrimPrefix(key, "Bearer ") | ||||
| 		key = strings.TrimPrefix(key, "sk-") | ||||
| @@ -94,21 +100,40 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		if token.Subnet != nil && *token.Subnet != "" { | ||||
| 			if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { | ||||
| 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | ||||
| 		if err != nil { | ||||
| 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		if !userEnabled { | ||||
| 		if !userEnabled || blacklist.IsUserBanned(token.UserId) { | ||||
| 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||
| 			return | ||||
| 		} | ||||
| 		requestModel, err := getRequestModel(c) | ||||
| 		if err != nil && shouldCheckModel(c) { | ||||
| 			abortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		c.Set("request_model", requestModel) | ||||
| 		if token.Models != nil && *token.Models != "" { | ||||
| 			c.Set("available_models", *token.Models) | ||||
| 			if requestModel != "" && !isModelInList(requestModel, *token.Models) { | ||||
| 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		c.Set("id", token.UserId) | ||||
| 		c.Set("token_id", token.Id) | ||||
| 		c.Set("token_name", token.Name) | ||||
| 		if len(parts) > 1 { | ||||
| 			if model.IsAdmin(token.UserId) { | ||||
| 				c.Set("channelId", parts[1]) | ||||
| 				c.Set("specific_channel_id", parts[1]) | ||||
| 			} else { | ||||
| 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||
| 				return | ||||
| @@ -117,3 +142,19 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func shouldCheckModel(c *gin.Context) bool { | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
| @@ -2,14 +2,13 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| @@ -21,8 +20,9 @@ func Distribute() func(c *gin.Context) { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		c.Set("group", userGroup) | ||||
| 		var requestModel string | ||||
| 		var channel *model.Channel | ||||
| 		channelId, ok := c.Get("channelId") | ||||
| 		channelId, ok := c.Get("specific_channel_id") | ||||
| 		if ok { | ||||
| 			id, err := strconv.Atoi(channelId.(string)) | ||||
| 			if err != nil { | ||||
| @@ -34,41 +34,16 @@ func Distribute() func(c *gin.Context) { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||
| 				return | ||||
| 			} | ||||
| 			if channel.Status != common.ChannelStatusEnabled { | ||||
| 			if channel.Status != model.ChannelStatusEnabled { | ||||
| 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			// Select a channel for the user | ||||
| 			var modelRequest ModelRequest | ||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			requestModel = c.GetString("request_model") | ||||
| 			var err error | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||
| 				return | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "text-moderation-stable" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = c.Param("model") | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e-2" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "whisper-1" | ||||
| 				} | ||||
| 			} | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) | ||||
| 				if channel != nil { | ||||
| 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					message = "数据库一致性已被破坏,请联系管理员" | ||||
| @@ -77,29 +52,34 @@ func Distribute() func(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		c.Set("channel", channel.Type) | ||||
| 		c.Set("channel_id", channel.Id) | ||||
| 		c.Set("channel_name", channel.Name) | ||||
| 		c.Set("model_mapping", channel.GetModelMapping()) | ||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
| 		c.Set("base_url", channel.GetBaseURL()) | ||||
| 		// this is for backward compatibility | ||||
| 		switch channel.Type { | ||||
| 		case common.ChannelTypeAzure: | ||||
| 			c.Set(common.ConfigKeyAPIVersion, channel.Other) | ||||
| 		case common.ChannelTypeXunfei: | ||||
| 			c.Set(common.ConfigKeyAPIVersion, channel.Other) | ||||
| 		case common.ChannelTypeGemini: | ||||
| 			c.Set(common.ConfigKeyAPIVersion, channel.Other) | ||||
| 		case common.ChannelTypeAIProxyLibrary: | ||||
| 			c.Set(common.ConfigKeyLibraryID, channel.Other) | ||||
| 		case common.ChannelTypeAli: | ||||
| 			c.Set(common.ConfigKeyPlugin, channel.Other) | ||||
| 		} | ||||
| 		cfg, _ := channel.LoadConfig() | ||||
| 		for k, v := range cfg { | ||||
| 			c.Set(common.ConfigKeyPrefix+k, v) | ||||
| 		} | ||||
| 		SetupContextForSelectedChannel(c, channel, requestModel) | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { | ||||
| 	c.Set("channel", channel.Type) | ||||
| 	c.Set("channel_id", channel.Id) | ||||
| 	c.Set("channel_name", channel.Name) | ||||
| 	c.Set("model_mapping", channel.GetModelMapping()) | ||||
| 	c.Set("original_model", modelName) // for retry | ||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
| 	c.Set("base_url", channel.GetBaseURL()) | ||||
| 	// this is for backward compatibility | ||||
| 	switch channel.Type { | ||||
| 	case channeltype.Azure: | ||||
| 		c.Set(config.KeyAPIVersion, channel.Other) | ||||
| 	case channeltype.Xunfei: | ||||
| 		c.Set(config.KeyAPIVersion, channel.Other) | ||||
| 	case channeltype.Gemini: | ||||
| 		c.Set(config.KeyAPIVersion, channel.Other) | ||||
| 	case channeltype.AIProxyLibrary: | ||||
| 		c.Set(config.KeyLibraryID, channel.Other) | ||||
| 	case channeltype.Ali: | ||||
| 		c.Set(config.KeyPlugin, channel.Other) | ||||
| 	} | ||||
| 	cfg, _ := channel.LoadConfig() | ||||
| 	for k, v := range cfg { | ||||
| 		c.Set(config.KeyPrefix+k, v) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package middleware | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"net/http" | ||||
| 	"runtime/debug" | ||||
| @@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		defer func() { | ||||
| 			if err := recover(); err != nil { | ||||
| 				logger.SysError(fmt.Sprintf("panic detected: %v", err)) | ||||
| 				logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||
| 				ctx := c.Request.Context() | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) | ||||
| 				body, _ := common.GetRequestBody(c) | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) | ||||
| 				c.JSON(http.StatusInternalServerError, gin.H{ | ||||
| 					"error": gin.H{ | ||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | ||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), | ||||
| 						"type":    "one_api_panic", | ||||
| 					}, | ||||
| 				}) | ||||
|   | ||||
| @@ -9,7 +9,7 @@ import ( | ||||
|  | ||||
| func RequestId() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		id := helper.GetTimeString() + helper.GetRandomString(8) | ||||
| 		id := helper.GenRequestID() | ||||
| 		c.Set(logger.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | ||||
| 		c.Request = c.Request.WithContext(ctx) | ||||
|   | ||||
| @@ -1,9 +1,12 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| 	c.Abort() | ||||
| 	logger.Error(c.Request.Context(), message) | ||||
| } | ||||
|  | ||||
| func getRequestModel(c *gin.Context) (string, error) { | ||||
| 	var modelRequest ModelRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "text-moderation-stable" | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = c.Param("model") | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "dall-e-2" | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "whisper-1" | ||||
| 		} | ||||
| 	} | ||||
| 	return modelRequest.Model, nil | ||||
| } | ||||
|  | ||||
| func isModelInList(modelName string, models string) bool { | ||||
| 	modelList := strings.Split(models, ",") | ||||
| 	for _, model := range modelList { | ||||
| 		if modelName == model { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,10 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"gorm.io/gorm" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| @@ -13,7 +16,7 @@ type Ability struct { | ||||
| 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"` | ||||
| } | ||||
|  | ||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { | ||||
| 	ability := Ability{} | ||||
| 	groupCol := "`group`" | ||||
| 	trueVal := "1" | ||||
| @@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	} | ||||
|  | ||||
| 	var err error = nil | ||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||
| 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	var channelQuery *gorm.DB | ||||
| 	if ignoreFirstPriority { | ||||
| 		channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||
| 	} else { | ||||
| 		maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||
| 		channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	} | ||||
| 	if common.UsingSQLite || common.UsingPostgreSQL { | ||||
| 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||
| 	} else { | ||||
| @@ -49,7 +57,7 @@ func (channel *Channel) AddAbilities() error { | ||||
| 				Group:     group, | ||||
| 				Model:     model, | ||||
| 				ChannelId: channel.Id, | ||||
| 				Enabled:   channel.Status == common.ChannelStatusEnabled, | ||||
| 				Enabled:   channel.Status == ChannelStatusEnabled, | ||||
| 				Priority:  channel.Priority, | ||||
| 			} | ||||
| 			abilities = append(abilities, ability) | ||||
| @@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { | ||||
| func UpdateAbilityStatus(channelId int, status bool) error { | ||||
| 	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error | ||||
| } | ||||
|  | ||||
| func GetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||
| 	groupCol := "`group`" | ||||
| 	trueVal := "1" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 		trueVal = "true" | ||||
| 	} | ||||
| 	var models []string | ||||
| 	err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	sort.Strings(models) | ||||
| 	return models, err | ||||
| } | ||||
|   | ||||
| @@ -1,12 +1,14 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"math/rand" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| @@ -20,6 +22,7 @@ var ( | ||||
| 	UserId2GroupCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2QuotaCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2StatusCacheSeconds = config.SyncFrequency | ||||
| 	GroupModelsCacheSeconds   = config.SyncFrequency | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| @@ -70,31 +73,42 @@ func CacheGetUserGroup(id int) (group string, err error) { | ||||
| 	return group, err | ||||
| } | ||||
|  | ||||
| func CacheGetUserQuota(id int) (quota int, err error) { | ||||
| func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||
| 	quota, err = GetUserQuota(id) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "Redis set user quota error: "+err.Error()) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return GetUserQuota(id) | ||||
| 	} | ||||
| 	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) | ||||
| 	if err != nil { | ||||
| 		quota, err = GetUserQuota(id) | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("Redis set user quota error: " + err.Error()) | ||||
| 		} | ||||
| 		return quota, err | ||||
| 		return fetchAndUpdateUserQuota(ctx, id) | ||||
| 	} | ||||
| 	quota, err = strconv.Atoi(quotaString) | ||||
| 	return quota, err | ||||
| 	quota, err = strconv.ParseInt(quotaString, 10, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, nil | ||||
| 	} | ||||
| 	if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db | ||||
| 		logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) | ||||
| 		return fetchAndUpdateUserQuota(ctx, id) | ||||
| 	} | ||||
| 	return quota, nil | ||||
| } | ||||
|  | ||||
| func CacheUpdateUserQuota(id int) error { | ||||
| func CacheUpdateUserQuota(ctx context.Context, id int) error { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return nil | ||||
| 	} | ||||
| 	quota, err := GetUserQuota(id) | ||||
| 	quota, err := CacheGetUserQuota(ctx, id) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -102,7 +116,7 @@ func CacheUpdateUserQuota(id int) error { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func CacheDecreaseUserQuota(id int, quota int) error { | ||||
| func CacheDecreaseUserQuota(id int, quota int64) error { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -134,13 +148,32 @@ func CacheIsUserEnabled(userId int) (bool, error) { | ||||
| 	return userEnabled, err | ||||
| } | ||||
|  | ||||
| func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return GetGroupModels(ctx, group) | ||||
| 	} | ||||
| 	modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) | ||||
| 	if err == nil { | ||||
| 		return strings.Split(modelsStr, ","), nil | ||||
| 	} | ||||
| 	models, err := GetGroupModels(ctx, group) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("Redis set group models error: " + err.Error()) | ||||
| 	} | ||||
| 	return models, nil | ||||
| } | ||||
|  | ||||
| var group2model2channels map[string]map[string][]*Channel | ||||
| var channelSyncLock sync.RWMutex | ||||
|  | ||||
| func InitChannelCache() { | ||||
| 	newChannelId2channel := make(map[int]*Channel) | ||||
| 	var channels []*Channel | ||||
| 	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) | ||||
| 	DB.Where("status = ?", ChannelStatusEnabled).Find(&channels) | ||||
| 	for _, channel := range channels { | ||||
| 		newChannelId2channel[channel.Id] = channel | ||||
| 	} | ||||
| @@ -191,9 +224,9 @@ func SyncChannelCache(frequency int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { | ||||
| 	if !config.MemoryCacheEnabled { | ||||
| 		return GetRandomSatisfiedChannel(group, model) | ||||
| 		return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority) | ||||
| 	} | ||||
| 	channelSyncLock.RLock() | ||||
| 	defer channelSyncLock.RUnlock() | ||||
| @@ -213,5 +246,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | ||||
| 		} | ||||
| 	} | ||||
| 	idx := rand.Intn(endIdx) | ||||
| 	if ignoreFirstPriority { | ||||
| 		if endIdx < len(channels) { // which means there are more than one priority | ||||
| 			idx = random.RandRange(endIdx, len(channels)) | ||||
| 		} | ||||
| 	} | ||||
| 	return channels[idx], nil | ||||
| } | ||||
|   | ||||
| @@ -3,17 +3,23 @@ package model | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelStatusUnknown          = 0 | ||||
| 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! | ||||
| 	ChannelStatusManuallyDisabled = 2 // also don't use 0 | ||||
| 	ChannelStatusAutoDisabled     = 3 | ||||
| ) | ||||
|  | ||||
| type Channel struct { | ||||
| 	Id                 int     `json:"id"` | ||||
| 	Type               int     `json:"type" gorm:"default:0"` | ||||
| 	Key                string  `json:"key" gorm:"not null;index"` | ||||
| 	Key                string  `json:"key" gorm:"type:text"` | ||||
| 	Status             int     `json:"status" gorm:"default:1"` | ||||
| 	Name               string  `json:"name" gorm:"index"` | ||||
| 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||
| @@ -32,23 +38,22 @@ type Channel struct { | ||||
| 	Config             string  `json:"config"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { | ||||
| 	var channels []*Channel | ||||
| 	var err error | ||||
| 	if selectAll { | ||||
| 	switch scope { | ||||
| 	case "all": | ||||
| 		err = DB.Order("id desc").Find(&channels).Error | ||||
| 	} else { | ||||
| 	case "disabled": | ||||
| 		err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error | ||||
| 	default: | ||||
| 		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | ||||
| 	} | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| @@ -169,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) { | ||||
| } | ||||
|  | ||||
| func UpdateChannelStatusById(id int, status int) { | ||||
| 	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) | ||||
| 	err := UpdateAbilityStatus(id, status == ChannelStatusEnabled) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to update ability status: " + err.Error()) | ||||
| 	} | ||||
| @@ -179,7 +184,7 @@ func UpdateChannelStatusById(id int, status int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func UpdateChannelUsedQuota(id int, quota int) { | ||||
| func UpdateChannelUsedQuota(id int, quota int64) { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||
| 		return | ||||
| @@ -187,7 +192,7 @@ func UpdateChannelUsedQuota(id int, quota int) { | ||||
| 	updateChannelUsedQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func updateChannelUsedQuota(id int, quota int) { | ||||
| func updateChannelUsedQuota(id int, quota int64) { | ||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to update channel used quota: " + err.Error()) | ||||
| @@ -200,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) { | ||||
| } | ||||
|  | ||||
| func DeleteDisabledChannel() (int64, error) { | ||||
| 	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||
| 	result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|   | ||||
							
								
								
									
										46
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										46
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -7,7 +7,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| @@ -45,13 +44,28 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 		Type:      logType, | ||||
| 		Content:   content, | ||||
| 	} | ||||
| 	err := DB.Create(log).Error | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||
| func RecordTopupLog(userId int, content string, quota int) { | ||||
| 	log := &Log{ | ||||
| 		UserId:    userId, | ||||
| 		Username:  GetUsernameById(userId), | ||||
| 		CreatedAt: helper.GetTimestamp(), | ||||
| 		Type:      LogTypeTopup, | ||||
| 		Content:   content, | ||||
| 		Quota:     quota, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { | ||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| 	if !config.LogConsumeEnabled { | ||||
| 		return | ||||
| @@ -66,10 +80,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TokenName:        tokenName, | ||||
| 		ModelName:        modelName, | ||||
| 		Quota:            quota, | ||||
| 		Quota:            int(quota), | ||||
| 		ChannelId:        channelId, | ||||
| 	} | ||||
| 	err := DB.Create(log).Error | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||
| 	} | ||||
| @@ -78,9 +92,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | ||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | ||||
| 	var tx *gorm.DB | ||||
| 	if logType == LogTypeUnknown { | ||||
| 		tx = DB | ||||
| 		tx = LOG_DB | ||||
| 	} else { | ||||
| 		tx = DB.Where("type = ?", logType) | ||||
| 		tx = LOG_DB.Where("type = ?", logType) | ||||
| 	} | ||||
| 	if modelName != "" { | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| @@ -107,9 +121,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | ||||
| func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | ||||
| 	var tx *gorm.DB | ||||
| 	if logType == LogTypeUnknown { | ||||
| 		tx = DB.Where("user_id = ?", userId) | ||||
| 		tx = LOG_DB.Where("user_id = ?", userId) | ||||
| 	} else { | ||||
| 		tx = DB.Where("user_id = ? and type = ?", userId, logType) | ||||
| 		tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) | ||||
| 	} | ||||
| 	if modelName != "" { | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| @@ -128,17 +142,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int | ||||
| } | ||||
|  | ||||
| func SearchAllLogs(keyword string) (logs []*Log, err error) { | ||||
| 	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error | ||||
| 	err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error | ||||
| 	return logs, err | ||||
| } | ||||
|  | ||||
| func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error | ||||
| 	err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error | ||||
| 	return logs, err | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | ||||
| 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -162,7 +176,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| } | ||||
|  | ||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||
| 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -183,7 +197,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| } | ||||
|  | ||||
| func DeleteOldLog(targetTimestamp int64) (int64, error) { | ||||
| 	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||
| 	result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|  | ||||
| @@ -207,7 +221,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis | ||||
| 		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Raw(` | ||||
| 	err = LOG_DB.Raw(` | ||||
| 		SELECT `+groupSelect+`, | ||||
| 		model_name, count(1) as request_count, | ||||
| 		sum(quota) as quota, | ||||
|   | ||||
| @@ -4,8 +4,10 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"gorm.io/driver/mysql" | ||||
| 	"gorm.io/driver/postgres" | ||||
| 	"gorm.io/driver/sqlite" | ||||
| @@ -16,12 +18,13 @@ import ( | ||||
| ) | ||||
|  | ||||
| var DB *gorm.DB | ||||
| var LOG_DB *gorm.DB | ||||
|  | ||||
| func createRootAccountIfNeed() error { | ||||
| func CreateRootAccountIfNeed() error { | ||||
| 	var user User | ||||
| 	//if user.Status != util.UserStatusEnabled { | ||||
| 	if err := DB.First(&user).Error; err != nil { | ||||
| 		logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") | ||||
| 		logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456") | ||||
| 		hashedPassword, err := common.Password2Hash("123456") | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| @@ -29,20 +32,36 @@ func createRootAccountIfNeed() error { | ||||
| 		rootUser := User{ | ||||
| 			Username:    "root", | ||||
| 			Password:    hashedPassword, | ||||
| 			Role:        common.RoleRootUser, | ||||
| 			Status:      common.UserStatusEnabled, | ||||
| 			Role:        RoleRootUser, | ||||
| 			Status:      UserStatusEnabled, | ||||
| 			DisplayName: "Root User", | ||||
| 			AccessToken: helper.GetUUID(), | ||||
| 			Quota:       100000000, | ||||
| 			AccessToken: random.GetUUID(), | ||||
| 			Quota:       500000000000000, | ||||
| 		} | ||||
| 		DB.Create(&rootUser) | ||||
| 		if config.InitialRootToken != "" { | ||||
| 			logger.SysLog("creating initial root token as requested") | ||||
| 			token := Token{ | ||||
| 				Id:             1, | ||||
| 				UserId:         rootUser.Id, | ||||
| 				Key:            config.InitialRootToken, | ||||
| 				Status:         TokenStatusEnabled, | ||||
| 				Name:           "Initial Root Token", | ||||
| 				CreatedTime:    helper.GetTimestamp(), | ||||
| 				AccessedTime:   helper.GetTimestamp(), | ||||
| 				ExpiredTime:    -1, | ||||
| 				RemainQuota:    500000000000000, | ||||
| 				UnlimitedQuota: true, | ||||
| 			} | ||||
| 			DB.Create(&token) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func chooseDB() (*gorm.DB, error) { | ||||
| 	if os.Getenv("SQL_DSN") != "" { | ||||
| 		dsn := os.Getenv("SQL_DSN") | ||||
| func chooseDB(envName string) (*gorm.DB, error) { | ||||
| 	if os.Getenv(envName) != "" { | ||||
| 		dsn := os.Getenv(envName) | ||||
| 		if strings.HasPrefix(dsn, "postgres://") { | ||||
| 			// Use PostgreSQL | ||||
| 			logger.SysLog("using PostgreSQL as database") | ||||
| @@ -56,6 +75,7 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 		} | ||||
| 		// Use MySQL | ||||
| 		logger.SysLog("using MySQL as database") | ||||
| 		common.UsingMySQL = true | ||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||
| 			PrepareStmt: true, // precompile SQL | ||||
| 		}) | ||||
| @@ -69,67 +89,78 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func InitDB() (err error) { | ||||
| 	db, err := chooseDB() | ||||
| func InitDB(envName string) (db *gorm.DB, err error) { | ||||
| 	db, err = chooseDB(envName) | ||||
| 	if err == nil { | ||||
| 		if config.DebugEnabled { | ||||
| 		if config.DebugSQLEnabled { | ||||
| 			db = db.Debug() | ||||
| 		} | ||||
| 		DB = db | ||||
| 		sqlDB, err := DB.DB() | ||||
| 		sqlDB, err := db.DB() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) | ||||
| 		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||
|  | ||||
| 		if !config.IsMasterNode { | ||||
| 			return nil | ||||
| 			return db, err | ||||
| 		} | ||||
| 		if common.UsingMySQL { | ||||
| 			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||
| 		} | ||||
| 		logger.SysLog("database migration started") | ||||
| 		err = db.AutoMigrate(&Channel{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Token{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&User{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Option{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Redemption{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Ability{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Log{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		logger.SysLog("database migrated") | ||||
| 		err = createRootAccountIfNeed() | ||||
| 		return err | ||||
| 		return db, err | ||||
| 	} else { | ||||
| 		logger.FatalLog(err) | ||||
| 	} | ||||
| 	return err | ||||
| 	return db, err | ||||
| } | ||||
|  | ||||
| func CloseDB() error { | ||||
| 	sqlDB, err := DB.DB() | ||||
| func closeDB(db *gorm.DB) error { | ||||
| 	sqlDB, err := db.DB() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	err = sqlDB.Close() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func CloseDB() error { | ||||
| 	if LOG_DB != DB { | ||||
| 		err := closeDB(LOG_DB) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return closeDB(DB) | ||||
| } | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -57,16 +57,18 @@ func InitOptionMap() { | ||||
| 	config.OptionMap["WeChatServerAddress"] = "" | ||||
| 	config.OptionMap["WeChatServerToken"] = "" | ||||
| 	config.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||
| 	config.OptionMap["MessagePusherAddress"] = "" | ||||
| 	config.OptionMap["MessagePusherToken"] = "" | ||||
| 	config.OptionMap["TurnstileSiteKey"] = "" | ||||
| 	config.OptionMap["TurnstileSecretKey"] = "" | ||||
| 	config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) | ||||
| 	config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) | ||||
| 	config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) | ||||
| 	config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) | ||||
| 	config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) | ||||
| 	config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | ||||
| 	config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | ||||
| 	config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() | ||||
| 	config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) | ||||
| 	config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) | ||||
| 	config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) | ||||
| 	config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) | ||||
| 	config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) | ||||
| 	config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString() | ||||
| 	config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString() | ||||
| 	config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString() | ||||
| 	config.OptionMap["TopUpLink"] = config.TopUpLink | ||||
| 	config.OptionMap["ChatLink"] = config.ChatLink | ||||
| 	config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) | ||||
| @@ -79,6 +81,9 @@ func InitOptionMap() { | ||||
| func loadOptionsFromDatabase() { | ||||
| 	options, _ := AllOption() | ||||
| 	for _, option := range options { | ||||
| 		if option.Key == "ModelRatio" { | ||||
| 			option.Value = billingratio.AddNewMissingRatio(option.Value) | ||||
| 		} | ||||
| 		err := updateOptionMap(option.Key, option.Value) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("failed to update option map: " + err.Error()) | ||||
| @@ -167,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		config.GitHubClientId = value | ||||
| 	case "GitHubClientSecret": | ||||
| 		config.GitHubClientSecret = value | ||||
| 	case "LarkClientId": | ||||
| 		config.LarkClientId = value | ||||
| 	case "LarkClientSecret": | ||||
| 		config.LarkClientSecret = value | ||||
| 	case "Footer": | ||||
| 		config.Footer = value | ||||
| 	case "SystemName": | ||||
| @@ -179,28 +188,32 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		config.WeChatServerToken = value | ||||
| 	case "WeChatAccountQRCodeImageURL": | ||||
| 		config.WeChatAccountQRCodeImageURL = value | ||||
| 	case "MessagePusherAddress": | ||||
| 		config.MessagePusherAddress = value | ||||
| 	case "MessagePusherToken": | ||||
| 		config.MessagePusherToken = value | ||||
| 	case "TurnstileSiteKey": | ||||
| 		config.TurnstileSiteKey = value | ||||
| 	case "TurnstileSecretKey": | ||||
| 		config.TurnstileSecretKey = value | ||||
| 	case "QuotaForNewUser": | ||||
| 		config.QuotaForNewUser, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "QuotaForInviter": | ||||
| 		config.QuotaForInviter, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "QuotaForInvitee": | ||||
| 		config.QuotaForInvitee, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "QuotaRemindThreshold": | ||||
| 		config.QuotaRemindThreshold, _ = strconv.Atoi(value) | ||||
| 		config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "PreConsumedQuota": | ||||
| 		config.PreConsumedQuota, _ = strconv.Atoi(value) | ||||
| 		config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "RetryTimes": | ||||
| 		config.RetryTimes, _ = strconv.Atoi(value) | ||||
| 	case "ModelRatio": | ||||
| 		err = common.UpdateModelRatioByJSONString(value) | ||||
| 		err = billingratio.UpdateModelRatioByJSONString(value) | ||||
| 	case "GroupRatio": | ||||
| 		err = common.UpdateGroupRatioByJSONString(value) | ||||
| 		err = billingratio.UpdateGroupRatioByJSONString(value) | ||||
| 	case "CompletionRatio": | ||||
| 		err = common.UpdateCompletionRatioByJSONString(value) | ||||
| 		err = billingratio.UpdateCompletionRatioByJSONString(value) | ||||
| 	case "TopUpLink": | ||||
| 		config.TopUpLink = value | ||||
| 	case "ChatLink": | ||||
|   | ||||
| @@ -8,13 +8,19 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||
| 	RedemptionCodeStatusDisabled = 2 // also don't use 0 | ||||
| 	RedemptionCodeStatusUsed     = 3 // also don't use 0 | ||||
| ) | ||||
|  | ||||
| type Redemption struct { | ||||
| 	Id           int    `json:"id"` | ||||
| 	UserId       int    `json:"user_id"` | ||||
| 	Key          string `json:"key" gorm:"type:char(32);uniqueIndex"` | ||||
| 	Status       int    `json:"status" gorm:"default:1"` | ||||
| 	Name         string `json:"name" gorm:"index"` | ||||
| 	Quota        int    `json:"quota" gorm:"default:100"` | ||||
| 	Quota        int64  `json:"quota" gorm:"bigint;default:100"` | ||||
| 	CreatedTime  int64  `json:"created_time" gorm:"bigint"` | ||||
| 	RedeemedTime int64  `json:"redeemed_time" gorm:"bigint"` | ||||
| 	Count        int    `json:"count" gorm:"-:all"` // only for api request | ||||
| @@ -42,7 +48,7 @@ func GetRedemptionById(id int) (*Redemption, error) { | ||||
| 	return &redemption, err | ||||
| } | ||||
|  | ||||
| func Redeem(key string, userId int) (quota int, err error) { | ||||
| func Redeem(key string, userId int) (quota int64, err error) { | ||||
| 	if key == "" { | ||||
| 		return 0, errors.New("未提供兑换码") | ||||
| 	} | ||||
| @@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 		if err != nil { | ||||
| 			return errors.New("无效的兑换码") | ||||
| 		} | ||||
| 		if redemption.Status != common.RedemptionCodeStatusEnabled { | ||||
| 		if redemption.Status != RedemptionCodeStatusEnabled { | ||||
| 			return errors.New("该兑换码已被使用") | ||||
| 		} | ||||
| 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | ||||
| @@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 			return err | ||||
| 		} | ||||
| 		redemption.RedeemedTime = helper.GetTimestamp() | ||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | ||||
| 		redemption.Status = RedemptionCodeStatusUsed | ||||
| 		err = tx.Save(redemption).Error | ||||
| 		return err | ||||
| 	}) | ||||
|   | ||||
| @@ -7,27 +7,48 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value! | ||||
| 	TokenStatusDisabled  = 2 // also don't use 0 | ||||
| 	TokenStatusExpired   = 3 | ||||
| 	TokenStatusExhausted = 4 | ||||
| ) | ||||
|  | ||||
| type Token struct { | ||||
| 	Id             int    `json:"id"` | ||||
| 	UserId         int    `json:"user_id"` | ||||
| 	Key            string `json:"key" gorm:"type:char(48);uniqueIndex"` | ||||
| 	Status         int    `json:"status" gorm:"default:1"` | ||||
| 	Name           string `json:"name" gorm:"index" ` | ||||
| 	CreatedTime    int64  `json:"created_time" gorm:"bigint"` | ||||
| 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"` | ||||
| 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||
| 	RemainQuota    int    `json:"remain_quota" gorm:"default:0"` | ||||
| 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int    `json:"used_quota" gorm:"default:0"` // used quota | ||||
| 	Id             int     `json:"id"` | ||||
| 	UserId         int     `json:"user_id"` | ||||
| 	Key            string  `json:"key" gorm:"type:char(48);uniqueIndex"` | ||||
| 	Status         int     `json:"status" gorm:"default:1"` | ||||
| 	Name           string  `json:"name" gorm:"index" ` | ||||
| 	CreatedTime    int64   `json:"created_time" gorm:"bigint"` | ||||
| 	AccessedTime   int64   `json:"accessed_time" gorm:"bigint"` | ||||
| 	ExpiredTime    int64   `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||
| 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Models         *string `json:"models" gorm:"default:''"`           // allowed models | ||||
| 	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet | ||||
| } | ||||
|  | ||||
| func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { | ||||
| func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { | ||||
| 	var tokens []*Token | ||||
| 	var err error | ||||
| 	err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error | ||||
| 	query := DB.Where("user_id = ?", userId) | ||||
|  | ||||
| 	switch order { | ||||
| 	case "remain_quota": | ||||
| 		query = query.Order("unlimited_quota desc, remain_quota desc") | ||||
| 	case "used_quota": | ||||
| 		query = query.Order("used_quota desc") | ||||
| 	default: | ||||
| 		query = query.Order("id desc") | ||||
| 	} | ||||
|  | ||||
| 	err = query.Limit(num).Offset(startIdx).Find(&tokens).Error | ||||
| 	return tokens, err | ||||
| } | ||||
|  | ||||
| @@ -48,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) { | ||||
| 		} | ||||
| 		return nil, errors.New("令牌验证失败") | ||||
| 	} | ||||
| 	if token.Status == common.TokenStatusExhausted { | ||||
| 		return nil, errors.New("该令牌额度已用尽") | ||||
| 	} else if token.Status == common.TokenStatusExpired { | ||||
| 	if token.Status == TokenStatusExhausted { | ||||
| 		return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) | ||||
| 	} else if token.Status == TokenStatusExpired { | ||||
| 		return nil, errors.New("该令牌已过期") | ||||
| 	} | ||||
| 	if token.Status != common.TokenStatusEnabled { | ||||
| 	if token.Status != TokenStatusEnabled { | ||||
| 		return nil, errors.New("该令牌状态不可用") | ||||
| 	} | ||||
| 	if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { | ||||
| 		if !common.RedisEnabled { | ||||
| 			token.Status = common.TokenStatusExpired | ||||
| 			token.Status = TokenStatusExpired | ||||
| 			err := token.SelectUpdate() | ||||
| 			if err != nil { | ||||
| 				logger.SysError("failed to update token status" + err.Error()) | ||||
| @@ -69,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) { | ||||
| 	if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||
| 		if !common.RedisEnabled { | ||||
| 			// in this case, we can make sure the token is exhausted | ||||
| 			token.Status = common.TokenStatusExhausted | ||||
| 			token.Status = TokenStatusExhausted | ||||
| 			err := token.SelectUpdate() | ||||
| 			if err != nil { | ||||
| 				logger.SysError("failed to update token status" + err.Error()) | ||||
| @@ -109,7 +130,7 @@ func (token *Token) Insert() error { | ||||
| // Update Make sure your token's fields is completed, because this will update non-zero values | ||||
| func (token *Token) Update() error { | ||||
| 	var err error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -137,7 +158,7 @@ func DeleteTokenById(id int, userId int) (err error) { | ||||
| 	return token.Delete() | ||||
| } | ||||
|  | ||||
| func IncreaseTokenQuota(id int, quota int) (err error) { | ||||
| func IncreaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -148,7 +169,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return increaseTokenQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func increaseTokenQuota(id int, quota int) (err error) { | ||||
| func increaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||
| @@ -159,7 +180,7 @@ func increaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func DecreaseTokenQuota(id int, quota int) (err error) { | ||||
| func DecreaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -170,7 +191,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return decreaseTokenQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func decreaseTokenQuota(id int, quota int) (err error) { | ||||
| func decreaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||
| @@ -181,7 +202,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -213,7 +234,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| 			} | ||||
| 			if email != "" { | ||||
| 				topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) | ||||
| 				err = common.SendEmail(prompt, email, | ||||
| 				err = message.SendEmail(prompt, email, | ||||
| 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("failed to send email" + err.Error()) | ||||
| @@ -231,7 +252,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func PostConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
| 	token, err := GetTokenById(tokenId) | ||||
| 	if quota > 0 { | ||||
| 		err = DecreaseUserQuota(token.UserId, quota) | ||||
|   | ||||
| @@ -4,13 +4,27 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/blacklist" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	RoleGuestUser  = 0 | ||||
| 	RoleCommonUser = 1 | ||||
| 	RoleAdminUser  = 10 | ||||
| 	RoleRootUser   = 100 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||
| 	UserStatusDisabled = 2 // also don't use 0 | ||||
| 	UserStatusDeleted  = 3 | ||||
| ) | ||||
|  | ||||
| // User if you add sensitive fields, don't forget to clean them in setupLogin function. | ||||
| // Otherwise, the sensitive information will be saved on local storage in plain text! | ||||
| type User struct { | ||||
| @@ -23,11 +37,12 @@ type User struct { | ||||
| 	Email            string `json:"email" gorm:"index" validate:"max=50"` | ||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||
| 	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"` | ||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||
| 	Quota            int    `json:"quota" gorm:"type:int;default:0"` | ||||
| 	UsedQuota        int    `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota | ||||
| 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`               // request number | ||||
| 	Quota            int64  `json:"quota" gorm:"bigint;default:0"` | ||||
| 	UsedQuota        int64  `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota | ||||
| 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`             // request number | ||||
| 	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||
| 	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` | ||||
| 	InviterId        int    `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` | ||||
| @@ -39,8 +54,21 @@ func GetMaxUserId() int { | ||||
| 	return user.Id | ||||
| } | ||||
|  | ||||
| func GetAllUsers(startIdx int, num int) (users []*User, err error) { | ||||
| 	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error | ||||
| func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { | ||||
| 	query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted) | ||||
|  | ||||
| 	switch order { | ||||
| 	case "quota": | ||||
| 		query = query.Order("quota desc") | ||||
| 	case "used_quota": | ||||
| 		query = query.Order("used_quota desc") | ||||
| 	case "request_count": | ||||
| 		query = query.Order("request_count desc") | ||||
| 	default: | ||||
| 		query = query.Order("id desc") | ||||
| 	} | ||||
|  | ||||
| 	err = query.Find(&users).Error | ||||
| 	return users, err | ||||
| } | ||||
|  | ||||
| @@ -93,8 +121,8 @@ func (user *User) Insert(inviterId int) error { | ||||
| 		} | ||||
| 	} | ||||
| 	user.Quota = config.QuotaForNewUser | ||||
| 	user.AccessToken = helper.GetUUID() | ||||
| 	user.AffCode = helper.GetRandomString(4) | ||||
| 	user.AccessToken = random.GetUUID() | ||||
| 	user.AffCode = random.GetRandomString(4) | ||||
| 	result := DB.Create(user) | ||||
| 	if result.Error != nil { | ||||
| 		return result.Error | ||||
| @@ -123,6 +151,11 @@ func (user *User) Update(updatePassword bool) error { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	if user.Status == UserStatusDisabled { | ||||
| 		blacklist.BanUser(user.Id) | ||||
| 	} else if user.Status == UserStatusEnabled { | ||||
| 		blacklist.UnbanUser(user.Id) | ||||
| 	} | ||||
| 	err = DB.Model(user).Updates(user).Error | ||||
| 	return err | ||||
| } | ||||
| @@ -131,7 +164,10 @@ func (user *User) Delete() error { | ||||
| 	if user.Id == 0 { | ||||
| 		return errors.New("id 为空!") | ||||
| 	} | ||||
| 	err := DB.Delete(user).Error | ||||
| 	blacklist.BanUser(user.Id) | ||||
| 	user.Username = fmt.Sprintf("deleted_%s", random.GetUUID()) | ||||
| 	user.Status = UserStatusDeleted | ||||
| 	err := DB.Model(user).Updates(user).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -154,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) { | ||||
| 		} | ||||
| 	} | ||||
| 	okay := common.ValidatePasswordAndHash(password, user.Password) | ||||
| 	if !okay || user.Status != common.UserStatusEnabled { | ||||
| 	if !okay || user.Status != UserStatusEnabled { | ||||
| 		return errors.New("用户名或密码错误,或用户已被封禁") | ||||
| 	} | ||||
| 	return nil | ||||
| @@ -184,6 +220,14 @@ func (user *User) FillUserByGitHubId() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByLarkId() error { | ||||
| 	if user.LarkId == "" { | ||||
| 		return errors.New("lark id 为空!") | ||||
| 	} | ||||
| 	DB.Where(User{LarkId: user.LarkId}).First(user) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByWeChatId() error { | ||||
| 	if user.WeChatId == "" { | ||||
| 		return errors.New("WeChat id 为空!") | ||||
| @@ -212,6 +256,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsLarkIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsUsernameAlreadyTaken(username string) bool { | ||||
| 	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
| @@ -238,7 +286,7 @@ func IsAdmin(userId int) bool { | ||||
| 		logger.SysError("no such user " + err.Error()) | ||||
| 		return false | ||||
| 	} | ||||
| 	return user.Role >= common.RoleAdminUser | ||||
| 	return user.Role >= RoleAdminUser | ||||
| } | ||||
|  | ||||
| func IsUserEnabled(userId int) (bool, error) { | ||||
| @@ -250,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) { | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	return user.Status == common.UserStatusEnabled, nil | ||||
| 	return user.Status == UserStatusEnabled, nil | ||||
| } | ||||
|  | ||||
| func ValidateAccessToken(token string) (user *User) { | ||||
| @@ -265,12 +313,12 @@ func ValidateAccessToken(token string) (user *User) { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetUserQuota(id int) (quota int, err error) { | ||||
| func GetUserQuota(id int) (quota int64, err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error | ||||
| 	return quota, err | ||||
| } | ||||
|  | ||||
| func GetUserUsedQuota(id int) (quota int, err error) { | ||||
| func GetUserUsedQuota(id int) (quota int64, err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error | ||||
| 	return quota, err | ||||
| } | ||||
| @@ -290,7 +338,7 @@ func GetUserGroup(id int) (group string, err error) { | ||||
| 	return group, err | ||||
| } | ||||
|  | ||||
| func IncreaseUserQuota(id int, quota int) (err error) { | ||||
| func IncreaseUserQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -301,12 +349,12 @@ func IncreaseUserQuota(id int, quota int) (err error) { | ||||
| 	return increaseUserQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func increaseUserQuota(id int, quota int) (err error) { | ||||
| func increaseUserQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func DecreaseUserQuota(id int, quota int) (err error) { | ||||
| func DecreaseUserQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -317,17 +365,17 @@ func DecreaseUserQuota(id int, quota int) (err error) { | ||||
| 	return decreaseUserQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func decreaseUserQuota(id int, quota int) (err error) { | ||||
| func decreaseUserQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func GetRootUserEmail() (email string) { | ||||
| 	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) | ||||
| 	DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email) | ||||
| 	return email | ||||
| } | ||||
|  | ||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||
| @@ -336,7 +384,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||
| } | ||||
|  | ||||
| func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||
| func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||
| @@ -348,7 +396,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func updateUserUsedQuota(id int, quota int) { | ||||
| func updateUserUsedQuota(id int, quota int64) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||
|   | ||||
| @@ -16,12 +16,12 @@ const ( | ||||
| 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||
| ) | ||||
|  | ||||
| var batchUpdateStores []map[int]int | ||||
| var batchUpdateStores []map[int]int64 | ||||
| var batchUpdateLocks []sync.Mutex | ||||
|  | ||||
| func init() { | ||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||
| 		batchUpdateStores = append(batchUpdateStores, make(map[int]int)) | ||||
| 		batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) | ||||
| 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | ||||
| 	} | ||||
| } | ||||
| @@ -35,7 +35,7 @@ func InitBatchUpdater() { | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func addNewRecord(type_ int, id int, value int) { | ||||
| func addNewRecord(type_ int, id int, value int64) { | ||||
| 	batchUpdateLocks[type_].Lock() | ||||
| 	defer batchUpdateLocks[type_].Unlock() | ||||
| 	if _, ok := batchUpdateStores[type_][id]; !ok { | ||||
| @@ -50,7 +50,7 @@ func batchUpdate() { | ||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||
| 		batchUpdateLocks[i].Lock() | ||||
| 		store := batchUpdateStores[i] | ||||
| 		batchUpdateStores[i] = make(map[int]int) | ||||
| 		batchUpdateStores[i] = make(map[int]int64) | ||||
| 		batchUpdateLocks[i].Unlock() | ||||
| 		// TODO: maybe we can combine updates with same key? | ||||
| 		for key, value := range store { | ||||
| @@ -68,7 +68,7 @@ func batchUpdate() { | ||||
| 			case BatchUpdateTypeUsedQuota: | ||||
| 				updateUserUsedQuota(key, value) | ||||
| 			case BatchUpdateTypeRequestCount: | ||||
| 				updateUserRequestCount(key, value) | ||||
| 				updateUserRequestCount(key, int(value)) | ||||
| 			case BatchUpdateTypeChannelUsedQuota: | ||||
| 				updateChannelUsedQuota(key, value) | ||||
| 			} | ||||
|   | ||||
							
								
								
									
										54
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| ) | ||||
|  | ||||
| func notifyRootUser(subject string, content string) { | ||||
| 	if config.MessagePusherAddress != "" { | ||||
| 		err := message.SendMessage(subject, content, content) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) | ||||
| 		} else { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	err := message.SendEmail(subject, config.RootUserEmail, content) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // DisableChannel disable & notify | ||||
| func DisableChannel(channelId int, channelName string, reason string) { | ||||
| 	model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) | ||||
| 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) | ||||
| 	subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| func MetricDisableChannel(channelId int, successRate float64) { | ||||
| 	model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) | ||||
| 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) | ||||
| 	subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) | ||||
| 	content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", | ||||
| 		channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| // EnableChannel enable & notify | ||||
| func EnableChannel(channelId int, channelName string) { | ||||
| 	model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled) | ||||
| 	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) | ||||
| 	subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
							
								
								
									
										62
									
								
								monitor/manage.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								monitor/manage.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||
| 	if !config.AutomaticDisableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if statusCode == http.StatusUnauthorized { | ||||
| 		return true | ||||
| 	} | ||||
| 	switch err.Type { | ||||
| 	case "insufficient_quota": | ||||
| 		return true | ||||
| 	// https://docs.anthropic.com/claude/reference/errors | ||||
| 	case "authentication_error": | ||||
| 		return true | ||||
| 	case "permission_error": | ||||
| 		return true | ||||
| 	case "forbidden": | ||||
| 		return true | ||||
| 	} | ||||
| 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic | ||||
| 		return true | ||||
| 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||
| 		return true | ||||
| 	} | ||||
| 	//if strings.Contains(err.Message, "quota") { | ||||
| 	//	return true | ||||
| 	//} | ||||
| 	if strings.Contains(err.Message, "credit") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.Contains(err.Message, "balance") { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func ShouldEnableChannel(err error, openAIErr *model.Error) bool { | ||||
| 	if !config.AutomaticEnableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if openAIErr != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
							
								
								
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| var store = make(map[int][]bool) | ||||
| var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) | ||||
| var metricFailChan = make(chan int, config.MetricFailChanSize) | ||||
|  | ||||
| func consumeSuccess(channelId int) { | ||||
| 	if len(store[channelId]) > config.MetricQueueSize { | ||||
| 		store[channelId] = store[channelId][1:] | ||||
| 	} | ||||
| 	store[channelId] = append(store[channelId], true) | ||||
| } | ||||
|  | ||||
| func consumeFail(channelId int) (bool, float64) { | ||||
| 	if len(store[channelId]) > config.MetricQueueSize { | ||||
| 		store[channelId] = store[channelId][1:] | ||||
| 	} | ||||
| 	store[channelId] = append(store[channelId], false) | ||||
| 	successCount := 0 | ||||
| 	for _, success := range store[channelId] { | ||||
| 		if success { | ||||
| 			successCount++ | ||||
| 		} | ||||
| 	} | ||||
| 	successRate := float64(successCount) / float64(len(store[channelId])) | ||||
| 	if len(store[channelId]) < config.MetricQueueSize { | ||||
| 		return false, successRate | ||||
| 	} | ||||
| 	if successRate < config.MetricSuccessRateThreshold { | ||||
| 		store[channelId] = make([]bool, 0) | ||||
| 		return true, successRate | ||||
| 	} | ||||
| 	return false, successRate | ||||
| } | ||||
|  | ||||
| func metricSuccessConsumer() { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case channelId := <-metricSuccessChan: | ||||
| 			consumeSuccess(channelId) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func metricFailConsumer() { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case channelId := <-metricFailChan: | ||||
| 			disable, successRate := consumeFail(channelId) | ||||
| 			if disable { | ||||
| 				go MetricDisableChannel(channelId, successRate) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func init() { | ||||
| 	if config.EnableMetric { | ||||
| 		go metricSuccessConsumer() | ||||
| 		go metricFailConsumer() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Emit(channelId int, success bool) { | ||||
| 	if !config.EnableMetric { | ||||
| 		return | ||||
| 	} | ||||
| 	go func() { | ||||
| 		if success { | ||||
| 			metricSuccessChan <- channelId | ||||
| 		} else { | ||||
| 			metricFailChan <- channelId | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
| @@ -1,9 +1,10 @@ | ||||
| [//]: # (请按照以下格式关联 issue) | ||||
| [//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) | ||||
| [//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢) | ||||
| [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) | ||||
| [//]: # (开发者交流群:910657413) | ||||
| [//]: # (请在提交 PR 之前删除上面的注释) | ||||
|  | ||||
| close #issue_number | ||||
|  | ||||
| 我已确认该 PR 已自测通过,相关截图如下: | ||||
| 我已确认该 PR 已自测通过,相关截图如下: | ||||
| (此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图) | ||||
|   | ||||
							
								
								
									
										45
									
								
								relay/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								relay/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| package relay | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aiproxy" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/ali" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/baidu" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/gemini" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/ollama" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/palm" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/tencent" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/zhipu" | ||||
| 	"github.com/songquanpeng/one-api/relay/apitype" | ||||
| ) | ||||
|  | ||||
| func GetAdaptor(apiType int) adaptor.Adaptor { | ||||
| 	switch apiType { | ||||
| 	case apitype.AIProxyLibrary: | ||||
| 		return &aiproxy.Adaptor{} | ||||
| 	case apitype.Ali: | ||||
| 		return &ali.Adaptor{} | ||||
| 	case apitype.Anthropic: | ||||
| 		return &anthropic.Adaptor{} | ||||
| 	case apitype.Baidu: | ||||
| 		return &baidu.Adaptor{} | ||||
| 	case apitype.Gemini: | ||||
| 		return &gemini.Adaptor{} | ||||
| 	case apitype.OpenAI: | ||||
| 		return &openai.Adaptor{} | ||||
| 	case apitype.PaLM: | ||||
| 		return &palm.Adaptor{} | ||||
| 	case apitype.Tencent: | ||||
| 		return &tencent.Adaptor{} | ||||
| 	case apitype.Xunfei: | ||||
| 		return &xunfei.Adaptor{} | ||||
| 	case apitype.Zhipu: | ||||
| 		return &zhipu.Adaptor{} | ||||
| 	case apitype.Ollama: | ||||
| 		return &ollama.Adaptor{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @@ -4,10 +4,10 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
| @@ -15,16 +15,16 @@ import ( | ||||
| type Adaptor struct { | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||
| 	channel.SetupCommonRequestHeader(c, req, meta) | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
| @@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	aiProxyLibraryRequest := ConvertRequest(*request) | ||||
| 	aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) | ||||
| 	aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID) | ||||
| 	return aiProxyLibraryRequest, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return channel.DoRequestHelper(a, c, meta, requestBody) | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| @@ -1,6 +1,6 @@ | ||||
| package aiproxy | ||||
| 
 | ||||
| import "github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| import "github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 
 | ||||
| var ModelList = []string{""} | ||||
| 
 | ||||
| @@ -8,7 +8,8 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| @@ -53,7 +54,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      helper.GetUUID(), | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| @@ -66,7 +67,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion | ||||
| 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||
| 	choice.FinishReason = &constant.StopFinishReason | ||||
| 	return &openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      helper.GetUUID(), | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   "", | ||||
| @@ -78,7 +79,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = response.Content | ||||
| 	return &openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      helper.GetUUID(), | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   response.Model, | ||||
							
								
								
									
										105
									
								
								relay/adaptor/ali/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								relay/adaptor/ali/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/zh/dashscope/developer-reference/api-details | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	fullRequestURL := "" | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.Embeddings: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) | ||||
| 	case relaymode.ImagesGenerations: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) | ||||
| 	default: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) | ||||
| 	} | ||||
|  | ||||
| 	return fullRequestURL, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	if meta.IsStream { | ||||
| 		req.Header.Set("Accept", "text/event-stream") | ||||
| 		req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
|  | ||||
| 	if meta.Mode == relaymode.ImagesGenerations { | ||||
| 		req.Header.Set("X-DashScope-Async", "enable") | ||||
| 	} | ||||
| 	if c.GetString(config.KeyPlugin) != "" { | ||||
| 		req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin)) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	switch relayMode { | ||||
| 	case relaymode.Embeddings: | ||||
| 		aliEmbeddingRequest := ConvertEmbeddingRequest(*request) | ||||
| 		return aliEmbeddingRequest, nil | ||||
| 	default: | ||||
| 		aliRequest := ConvertRequest(*request) | ||||
| 		return aliRequest, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
|  | ||||
| 	aliRequest := ConvertImageRequest(*request) | ||||
| 	return aliRequest, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| 		switch meta.Mode { | ||||
| 		case relaymode.Embeddings: | ||||
| 			err, usage = EmbeddingHandler(c, resp) | ||||
| 		case relaymode.ImagesGenerations: | ||||
| 			err, usage = ImageHandler(c, resp) | ||||
| 		default: | ||||
| 			err, usage = Handler(c, resp) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "ali" | ||||
| } | ||||
| @@ -3,4 +3,5 @@ package ali | ||||
| var ModelList = []string{ | ||||
| 	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", | ||||
| 	"text-embedding-v1", | ||||
| 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", | ||||
| } | ||||
							
								
								
									
										192
									
								
								relay/adaptor/ali/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								relay/adaptor/ali/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,192 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	apiKey := c.Request.Header.Get("Authorization") | ||||
| 	apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 	responseFormat := c.GetString("response_format") | ||||
|  | ||||
| 	var aliTaskResponse TaskResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &aliTaskResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if aliTaskResponse.Message != "" { | ||||
| 		logger.SysError("aliAsyncTask err: " + string(responseBody)) | ||||
| 		return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if aliResponse.Output.TaskStatus != "SUCCEEDED" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: aliResponse.Output.Message, | ||||
| 				Type:    "ali_error", | ||||
| 				Param:   "", | ||||
| 				Code:    aliResponse.Output.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
|  | ||||
| 	fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { | ||||
| 	url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) | ||||
|  | ||||
| 	var aliResponse TaskResponse | ||||
|  | ||||
| 	req, err := http.NewRequest("GET", url, nil) | ||||
| 	if err != nil { | ||||
| 		return &aliResponse, err, nil | ||||
| 	} | ||||
|  | ||||
| 	req.Header.Set("Authorization", "Bearer "+key) | ||||
|  | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("aliAsyncTask client.Do err: " + err.Error()) | ||||
| 		return &aliResponse, err, nil | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 	var response TaskResponse | ||||
| 	err = json.Unmarshal(responseBody, &response) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) | ||||
| 		return &aliResponse, err, nil | ||||
| 	} | ||||
|  | ||||
| 	return &response, nil, responseBody | ||||
| } | ||||
|  | ||||
| func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) { | ||||
| 	waitSeconds := 2 | ||||
| 	step := 0 | ||||
| 	maxStep := 20 | ||||
|  | ||||
| 	var taskResponse TaskResponse | ||||
| 	var responseBody []byte | ||||
|  | ||||
| 	for { | ||||
| 		step++ | ||||
| 		rsp, err, body := asyncTask(taskID, key) | ||||
| 		responseBody = body | ||||
| 		if err != nil { | ||||
| 			return &taskResponse, responseBody, err | ||||
| 		} | ||||
|  | ||||
| 		if rsp.Output.TaskStatus == "" { | ||||
| 			return &taskResponse, responseBody, nil | ||||
| 		} | ||||
|  | ||||
| 		switch rsp.Output.TaskStatus { | ||||
| 		case "FAILED": | ||||
| 			fallthrough | ||||
| 		case "CANCELED": | ||||
| 			fallthrough | ||||
| 		case "SUCCEEDED": | ||||
| 			fallthrough | ||||
| 		case "UNKNOWN": | ||||
| 			return rsp, responseBody, nil | ||||
| 		} | ||||
| 		if step >= maxStep { | ||||
| 			break | ||||
| 		} | ||||
| 		time.Sleep(time.Duration(waitSeconds) * time.Second) | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { | ||||
| 	imageResponse := openai.ImageResponse{ | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 	} | ||||
|  | ||||
| 	for _, data := range response.Output.Results { | ||||
| 		var b64Json string | ||||
| 		if responseFormat == "b64_json" { | ||||
| 			// 读取 data.Url 的图片数据并转存到 b64Json | ||||
| 			imageData, err := getImageData(data.Url) | ||||
| 			if err != nil { | ||||
| 				// 处理获取图片数据失败的情况 | ||||
| 				logger.SysError("getImageData Error getting image data: " + err.Error()) | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// 将图片数据转为 Base64 编码的字符串 | ||||
| 			b64Json = Base64Encode(imageData) | ||||
| 		} else { | ||||
| 			// 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image | ||||
| 			b64Json = data.B64Image | ||||
| 		} | ||||
|  | ||||
| 		imageResponse.Data = append(imageResponse.Data, openai.ImageData{ | ||||
| 			Url:           data.Url, | ||||
| 			B64Json:       b64Json, | ||||
| 			RevisedPrompt: "", | ||||
| 		}) | ||||
| 	} | ||||
| 	return &imageResponse | ||||
| } | ||||
|  | ||||
| func getImageData(url string) ([]byte, error) { | ||||
| 	response, err := http.Get(url) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer response.Body.Close() | ||||
|  | ||||
| 	imageData, err := io.ReadAll(response.Body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return imageData, nil | ||||
| } | ||||
|  | ||||
| func Base64Encode(data []byte) string { | ||||
| 	b64Json := base64.StdEncoding.EncodeToString(data) | ||||
| 	return b64Json | ||||
| } | ||||
| @@ -7,7 +7,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| @@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		enableSearch = true | ||||
| 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||
| 	} | ||||
| 	if request.TopP >= 1 { | ||||
| 		request.TopP = 0.9999 | ||||
| 	} | ||||
| 	return &ChatRequest{ | ||||
| 		Model: aliModel, | ||||
| 		Input: Input{ | ||||
| @@ -42,6 +45,12 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			EnableSearch:      enableSearch, | ||||
| 			IncrementalOutput: request.Stream, | ||||
| 			Seed:              uint64(request.Seed), | ||||
| 			MaxTokens:         request.MaxTokens, | ||||
| 			Temperature:       request.Temperature, | ||||
| 			TopP:              request.TopP, | ||||
| 			TopK:              request.TopK, | ||||
| 			ResultFormat:      "message", | ||||
| 			Tools:             request.Tools, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| @@ -57,6 +66,17 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingReque | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func ConvertImageRequest(request model.ImageRequest) *ImageRequest { | ||||
| 	var imageRequest ImageRequest | ||||
| 	imageRequest.Input.Prompt = request.Prompt | ||||
| 	imageRequest.Model = request.Model | ||||
| 	imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) | ||||
| 	imageRequest.Parameters.N = request.N | ||||
| 	imageRequest.ResponseFormat = request.ResponseFormat | ||||
| 
 | ||||
| 	return &imageRequest | ||||
| } | ||||
| 
 | ||||
| func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var aliResponse EmbeddingResponse | ||||
| 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | ||||
| @@ -111,19 +131,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR | ||||
| } | ||||
| 
 | ||||
| func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 		Choices: response.Output.Choices, | ||||
| 		Usage: model.Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| @@ -134,10 +146,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| } | ||||
| 
 | ||||
| func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	if len(aliResponse.Output.Choices) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	aliChoice := aliResponse.Output.Choices[0] | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	if aliResponse.Output.FinishReason != "null" { | ||||
| 		finishReason := aliResponse.Output.FinishReason | ||||
| 	choice.Delta = aliChoice.Message | ||||
| 	if aliChoice.FinishReason != "null" { | ||||
| 		finishReason := aliChoice.FinishReason | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| @@ -198,6 +214,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			} | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			//lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| @@ -220,6 +239,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| } | ||||
| 
 | ||||
| func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var aliResponse ChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| @@ -229,6 +249,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	logger.Debugf(ctx, "response body: %s\n", responseBody) | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
							
								
								
									
										154
									
								
								relay/adaptor/ali/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								relay/adaptor/ali/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,154 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| } | ||||
|  | ||||
| type Input struct { | ||||
| 	//Prompt   string       `json:"prompt"` | ||||
| 	Messages []Message `json:"messages"` | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	TopP              float64      `json:"top_p,omitempty"` | ||||
| 	TopK              int          `json:"top_k,omitempty"` | ||||
| 	Seed              uint64       `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool         `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool         `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int          `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64      `json:"temperature,omitempty"` | ||||
| 	ResultFormat      string       `json:"result_format,omitempty"` | ||||
| 	Tools             []model.Tool `json:"tools,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Model      string     `json:"model"` | ||||
| 	Input      Input      `json:"input"` | ||||
| 	Parameters Parameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Input struct { | ||||
| 		Prompt         string `json:"prompt"` | ||||
| 		NegativePrompt string `json:"negative_prompt,omitempty"` | ||||
| 	} `json:"input"` | ||||
| 	Parameters struct { | ||||
| 		Size  string `json:"size,omitempty"` | ||||
| 		N     int    `json:"n,omitempty"` | ||||
| 		Steps string `json:"steps,omitempty"` | ||||
| 		Scale string `json:"scale,omitempty"` | ||||
| 	} `json:"parameters,omitempty"` | ||||
| 	ResponseFormat string `json:"response_format,omitempty"` | ||||
| } | ||||
|  | ||||
| type TaskResponse struct { | ||||
| 	StatusCode int    `json:"status_code,omitempty"` | ||||
| 	RequestId  string `json:"request_id,omitempty"` | ||||
| 	Code       string `json:"code,omitempty"` | ||||
| 	Message    string `json:"message,omitempty"` | ||||
| 	Output     struct { | ||||
| 		TaskId     string `json:"task_id,omitempty"` | ||||
| 		TaskStatus string `json:"task_status,omitempty"` | ||||
| 		Code       string `json:"code,omitempty"` | ||||
| 		Message    string `json:"message,omitempty"` | ||||
| 		Results    []struct { | ||||
| 			B64Image string `json:"b64_image,omitempty"` | ||||
| 			Url      string `json:"url,omitempty"` | ||||
| 			Code     string `json:"code,omitempty"` | ||||
| 			Message  string `json:"message,omitempty"` | ||||
| 		} `json:"results,omitempty"` | ||||
| 		TaskMetrics struct { | ||||
| 			Total     int `json:"TOTAL,omitempty"` | ||||
| 			Succeeded int `json:"SUCCEEDED,omitempty"` | ||||
| 			Failed    int `json:"FAILED,omitempty"` | ||||
| 		} `json:"task_metrics,omitempty"` | ||||
| 	} `json:"output,omitempty"` | ||||
| 	Usage Usage `json:"usage"` | ||||
| } | ||||
|  | ||||
| type Header struct { | ||||
| 	Action       string `json:"action,omitempty"` | ||||
| 	Streaming    string `json:"streaming,omitempty"` | ||||
| 	TaskID       string `json:"task_id,omitempty"` | ||||
| 	Event        string `json:"event,omitempty"` | ||||
| 	ErrorCode    string `json:"error_code,omitempty"` | ||||
| 	ErrorMessage string `json:"error_message,omitempty"` | ||||
| 	Attributes   any    `json:"attributes,omitempty"` | ||||
| } | ||||
|  | ||||
| type Payload struct { | ||||
| 	Model      string `json:"model,omitempty"` | ||||
| 	Task       string `json:"task,omitempty"` | ||||
| 	TaskGroup  string `json:"task_group,omitempty"` | ||||
| 	Function   string `json:"function,omitempty"` | ||||
| 	Parameters struct { | ||||
| 		SampleRate int     `json:"sample_rate,omitempty"` | ||||
| 		Rate       float64 `json:"rate,omitempty"` | ||||
| 		Format     string  `json:"format,omitempty"` | ||||
| 	} `json:"parameters,omitempty"` | ||||
| 	Input struct { | ||||
| 		Text string `json:"text,omitempty"` | ||||
| 	} `json:"input,omitempty"` | ||||
| 	Usage struct { | ||||
| 		Characters int `json:"characters,omitempty"` | ||||
| 	} `json:"usage,omitempty"` | ||||
| } | ||||
|  | ||||
| type WSSMessage struct { | ||||
| 	Header  Header  `json:"header,omitempty"` | ||||
| 	Payload Payload `json:"payload,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Input struct { | ||||
| 		Texts []string `json:"texts"` | ||||
| 	} `json:"input"` | ||||
| 	Parameters *struct { | ||||
| 		TextType string `json:"text_type,omitempty"` | ||||
| 	} `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type Embedding struct { | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	TextIndex int       `json:"text_index"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Output struct { | ||||
| 		Embeddings []Embedding `json:"embeddings"` | ||||
| 	} `json:"output"` | ||||
| 	Usage Usage `json:"usage"` | ||||
| 	Error | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type Output struct { | ||||
| 	//Text         string                      `json:"text"` | ||||
| 	//FinishReason string                      `json:"finish_reason"` | ||||
| 	Choices []openai.TextResponseChoice `json:"choices"` | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Output Output `json:"output"` | ||||
| 	Usage  Usage  `json:"usage"` | ||||
| 	Error | ||||
| } | ||||
| @@ -4,10 +4,9 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
| @@ -15,22 +14,23 @@ import ( | ||||
| type Adaptor struct { | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||
| 	channel.SetupCommonRequestHeader(c, req, meta) | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("x-api-key", meta.APIKey) | ||||
| 	anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 	if anthropicVersion == "" { | ||||
| 		anthropicVersion = "2023-06-01" | ||||
| 	} | ||||
| 	req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 	req.Header.Set("anthropic-beta", "messages-2023-12-15") | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| @@ -41,15 +41,20 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	return ConvertRequest(*request), nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return channel.DoRequestHelper(a, c, meta, requestBody) | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		var responseText string | ||||
| 		err, responseText = StreamHandler(c, resp) | ||||
| 		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} | ||||
| @@ -61,5 +66,5 @@ func (a *Adaptor) GetModelList() []string { | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "authropic" | ||||
| 	return "anthropic" | ||||
| } | ||||
							
								
								
									
										8
									
								
								relay/adaptor/anthropic/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								relay/adaptor/anthropic/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| package anthropic | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"claude-instant-1.2", "claude-2.0", "claude-2.1", | ||||
| 	"claude-3-haiku-20240307", | ||||
| 	"claude-3-sonnet-20240229", | ||||
| 	"claude-3-opus-20240229", | ||||
| } | ||||
							
								
								
									
										273
									
								
								relay/adaptor/anthropic/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										273
									
								
								relay/adaptor/anthropic/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,273 @@ | ||||
| package anthropic | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/image" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func stopReasonClaude2OpenAI(reason *string) string { | ||||
| 	if reason == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	switch *reason { | ||||
| 	case "end_turn": | ||||
| 		return "stop" | ||||
| 	case "stop_sequence": | ||||
| 		return "stop" | ||||
| 	case "max_tokens": | ||||
| 		return "length" | ||||
| 	default: | ||||
| 		return *reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	claudeRequest := Request{ | ||||
| 		Model:       textRequest.Model, | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 		TopP:        textRequest.TopP, | ||||
| 		TopK:        textRequest.TopK, | ||||
| 		Stream:      textRequest.Stream, | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokens == 0 { | ||||
| 		claudeRequest.MaxTokens = 4096 | ||||
| 	} | ||||
| 	// legacy model name mapping | ||||
| 	if claudeRequest.Model == "claude-instant-1" { | ||||
| 		claudeRequest.Model = "claude-instant-1.1" | ||||
| 	} else if claudeRequest.Model == "claude-2" { | ||||
| 		claudeRequest.Model = "claude-2.1" | ||||
| 	} | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		if message.Role == "system" && claudeRequest.System == "" { | ||||
| 			claudeRequest.System = message.StringContent() | ||||
| 			continue | ||||
| 		} | ||||
| 		claudeMessage := Message{ | ||||
| 			Role: message.Role, | ||||
| 		} | ||||
| 		var content Content | ||||
| 		if message.IsStringContent() { | ||||
| 			content.Type = "text" | ||||
| 			content.Text = message.StringContent() | ||||
| 			claudeMessage.Content = append(claudeMessage.Content, content) | ||||
| 			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||
| 			continue | ||||
| 		} | ||||
| 		var contents []Content | ||||
| 		openaiContent := message.ParseContent() | ||||
| 		for _, part := range openaiContent { | ||||
| 			var content Content | ||||
| 			if part.Type == model.ContentTypeText { | ||||
| 				content.Type = "text" | ||||
| 				content.Text = part.Text | ||||
| 			} else if part.Type == model.ContentTypeImageURL { | ||||
| 				content.Type = "image" | ||||
| 				content.Source = &ImageSource{ | ||||
| 					Type: "base64", | ||||
| 				} | ||||
| 				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) | ||||
| 				content.Source.MediaType = mimeType | ||||
| 				content.Source.Data = data | ||||
| 			} | ||||
| 			contents = append(contents, content) | ||||
| 		} | ||||
| 		claudeMessage.Content = contents | ||||
| 		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||
| 	} | ||||
| 	return &claudeRequest | ||||
| } | ||||
|  | ||||
| // https://docs.anthropic.com/claude/reference/messages-streaming | ||||
| func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { | ||||
| 	var response *Response | ||||
| 	var responseText string | ||||
| 	var stopReason string | ||||
| 	switch claudeResponse.Type { | ||||
| 	case "message_start": | ||||
| 		return nil, claudeResponse.Message | ||||
| 	case "content_block_start": | ||||
| 		if claudeResponse.ContentBlock != nil { | ||||
| 			responseText = claudeResponse.ContentBlock.Text | ||||
| 		} | ||||
| 	case "content_block_delta": | ||||
| 		if claudeResponse.Delta != nil { | ||||
| 			responseText = claudeResponse.Delta.Text | ||||
| 		} | ||||
| 	case "message_delta": | ||||
| 		if claudeResponse.Usage != nil { | ||||
| 			response = &Response{ | ||||
| 				Usage: *claudeResponse.Usage, | ||||
| 			} | ||||
| 		} | ||||
| 		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { | ||||
| 			stopReason = *claudeResponse.Delta.StopReason | ||||
| 		} | ||||
| 	} | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = responseText | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	finishReason := stopReasonClaude2OpenAI(&stopReason) | ||||
| 	if finishReason != "null" { | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	var openaiResponse openai.ChatCompletionsStreamResponse | ||||
| 	openaiResponse.Object = "chat.completion.chunk" | ||||
| 	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &openaiResponse, response | ||||
| } | ||||
|  | ||||
| func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | ||||
| 	var responseText string | ||||
| 	if len(claudeResponse.Content) > 0 { | ||||
| 		responseText = claudeResponse.Content[0].Text | ||||
| 	} | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: responseText, | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), | ||||
| 		Model:   claudeResponse.Model, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { | ||||
| 				continue | ||||
| 			} | ||||
| 			if !strings.HasPrefix(data, "data: ") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "data: ") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var usage model.Usage | ||||
| 	var modelName string | ||||
| 	var id string | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var claudeResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, meta := streamResponseClaude2OpenAI(&claudeResponse) | ||||
| 			if meta != nil { | ||||
| 				usage.PromptTokens += meta.Usage.InputTokens | ||||
| 				usage.CompletionTokens += meta.Usage.OutputTokens | ||||
| 				modelName = meta.Model | ||||
| 				id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||
| 				return true | ||||
| 			} | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			response.Id = id | ||||
| 			response.Model = modelName | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var claudeResponse Response | ||||
| 	err = json.Unmarshal(responseBody, &claudeResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if claudeResponse.Error.Type != "" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: claudeResponse.Error.Message, | ||||
| 				Type:    claudeResponse.Error.Type, | ||||
| 				Param:   "", | ||||
| 				Code:    claudeResponse.Error.Type, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||
| 	fullTextResponse.Model = modelName | ||||
| 	usage := model.Usage{ | ||||
| 		PromptTokens:     claudeResponse.Usage.InputTokens, | ||||
| 		CompletionTokens: claudeResponse.Usage.OutputTokens, | ||||
| 		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
							
								
								
									
										75
									
								
								relay/adaptor/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								relay/adaptor/anthropic/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | ||||
| package anthropic | ||||
|  | ||||
| // https://docs.anthropic.com/claude/reference/messages_post | ||||
|  | ||||
| type Metadata struct { | ||||
| 	UserId string `json:"user_id"` | ||||
| } | ||||
|  | ||||
| type ImageSource struct { | ||||
| 	Type      string `json:"type"` | ||||
| 	MediaType string `json:"media_type"` | ||||
| 	Data      string `json:"data"` | ||||
| } | ||||
|  | ||||
| type Content struct { | ||||
| 	Type   string       `json:"type"` | ||||
| 	Text   string       `json:"text,omitempty"` | ||||
| 	Source *ImageSource `json:"source,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string    `json:"role"` | ||||
| 	Content []Content `json:"content"` | ||||
| } | ||||
|  | ||||
| type Request struct { | ||||
| 	Model         string    `json:"model"` | ||||
| 	Messages      []Message `json:"messages"` | ||||
| 	System        string    `json:"system,omitempty"` | ||||
| 	MaxTokens     int       `json:"max_tokens,omitempty"` | ||||
| 	StopSequences []string  `json:"stop_sequences,omitempty"` | ||||
| 	Stream        bool      `json:"stream,omitempty"` | ||||
| 	Temperature   float64   `json:"temperature,omitempty"` | ||||
| 	TopP          float64   `json:"top_p,omitempty"` | ||||
| 	TopK          int       `json:"top_k,omitempty"` | ||||
| 	//Metadata    `json:"metadata,omitempty"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Type    string `json:"type"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	Id           string    `json:"id"` | ||||
| 	Type         string    `json:"type"` | ||||
| 	Role         string    `json:"role"` | ||||
| 	Content      []Content `json:"content"` | ||||
| 	Model        string    `json:"model"` | ||||
| 	StopReason   *string   `json:"stop_reason"` | ||||
| 	StopSequence *string   `json:"stop_sequence"` | ||||
| 	Usage        Usage     `json:"usage"` | ||||
| 	Error        Error     `json:"error"` | ||||
| } | ||||
|  | ||||
| type Delta struct { | ||||
| 	Type         string  `json:"type"` | ||||
| 	Text         string  `json:"text"` | ||||
| 	StopReason   *string `json:"stop_reason"` | ||||
| 	StopSequence *string `json:"stop_sequence"` | ||||
| } | ||||
|  | ||||
| type StreamResponse struct { | ||||
| 	Type         string    `json:"type"` | ||||
| 	Message      *Response `json:"message"` | ||||
| 	Index        int       `json:"index"` | ||||
| 	ContentBlock *Content  `json:"content_block"` | ||||
| 	Delta        *Delta    `json:"delta"` | ||||
| 	Usage        *Usage    `json:"usage"` | ||||
| } | ||||
							
								
								
									
										15
									
								
								relay/adaptor/azure/helper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								relay/adaptor/azure/helper.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| package azure | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| func GetAPIVersion(c *gin.Context) string { | ||||
| 	query := c.Request.URL.Query() | ||||
| 	apiVersion := query.Get("api-version") | ||||
| 	if apiVersion == "" { | ||||
| 		apiVersion = c.GetString(config.KeyAPIVersion) | ||||
| 	} | ||||
| 	return apiVersion | ||||
| } | ||||
							
								
								
									
										7
									
								
								relay/adaptor/baichuan/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/adaptor/baichuan/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package baichuan | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"Baichuan2-Turbo", | ||||
| 	"Baichuan2-Turbo-192k", | ||||
| 	"Baichuan-Text-Embedding", | ||||
| } | ||||
							
								
								
									
										143
									
								
								relay/adaptor/baidu/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								relay/adaptor/baidu/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| package baidu | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t | ||||
| 	suffix := "chat/" | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "Embedding") { | ||||
| 		suffix = "embeddings/" | ||||
| 	} | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "bge-large") { | ||||
| 		suffix = "embeddings/" | ||||
| 	} | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "tao-8k") { | ||||
| 		suffix = "embeddings/" | ||||
| 	} | ||||
| 	switch meta.ActualModelName { | ||||
| 	case "ERNIE-4.0": | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-Bot-4": | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-Bot": | ||||
| 		suffix += "completions" | ||||
| 	case "ERNIE-Bot-turbo": | ||||
| 		suffix += "eb-instant" | ||||
| 	case "ERNIE-Speed": | ||||
| 		suffix += "ernie_speed" | ||||
| 	case "ERNIE-4.0-8K": | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-3.5-8K": | ||||
| 		suffix += "completions" | ||||
| 	case "ERNIE-3.5-8K-0205": | ||||
| 		suffix += "ernie-3.5-8k-0205" | ||||
| 	case "ERNIE-3.5-8K-1222": | ||||
| 		suffix += "ernie-3.5-8k-1222" | ||||
| 	case "ERNIE-Bot-8K": | ||||
| 		suffix += "ernie_bot_8k" | ||||
| 	case "ERNIE-3.5-4K-0205": | ||||
| 		suffix += "ernie-3.5-4k-0205" | ||||
| 	case "ERNIE-Speed-8K": | ||||
| 		suffix += "ernie_speed" | ||||
| 	case "ERNIE-Speed-128K": | ||||
| 		suffix += "ernie-speed-128k" | ||||
| 	case "ERNIE-Lite-8K-0922": | ||||
| 		suffix += "eb-instant" | ||||
| 	case "ERNIE-Lite-8K-0308": | ||||
| 		suffix += "ernie-lite-8k" | ||||
| 	case "ERNIE-Tiny-8K": | ||||
| 		suffix += "ernie-tiny-8k" | ||||
| 	case "BLOOMZ-7B": | ||||
| 		suffix += "bloomz_7b1" | ||||
| 	case "Embedding-V1": | ||||
| 		suffix += "embedding-v1" | ||||
| 	case "bge-large-zh": | ||||
| 		suffix += "bge_large_zh" | ||||
| 	case "bge-large-en": | ||||
| 		suffix += "bge_large_en" | ||||
| 	case "tao-8k": | ||||
| 		suffix += "tao_8k" | ||||
| 	default: | ||||
| 		suffix += strings.ToLower(meta.ActualModelName) | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) | ||||
| 	var accessToken string | ||||
| 	var err error | ||||
| 	if accessToken, err = GetAccessToken(meta.APIKey); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	fullRequestURL += "?access_token=" + accessToken | ||||
| 	return fullRequestURL, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	switch relayMode { | ||||
| 	case relaymode.Embeddings: | ||||
| 		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) | ||||
| 		return baiduEmbeddingRequest, nil | ||||
| 	default: | ||||
| 		baiduRequest := ConvertRequest(*request) | ||||
| 		return baiduRequest, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| 		switch meta.Mode { | ||||
| 		case relaymode.Embeddings: | ||||
| 			err, usage = EmbeddingHandler(c, resp) | ||||
| 		default: | ||||
| 			err, usage = Handler(c, resp) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "baidu" | ||||
| } | ||||
							
								
								
									
										20
									
								
								relay/adaptor/baidu/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								relay/adaptor/baidu/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| package baidu | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"ERNIE-4.0-8K", | ||||
| 	"ERNIE-3.5-8K", | ||||
| 	"ERNIE-3.5-8K-0205", | ||||
| 	"ERNIE-3.5-8K-1222", | ||||
| 	"ERNIE-Bot-8K", | ||||
| 	"ERNIE-3.5-4K-0205", | ||||
| 	"ERNIE-Speed-8K", | ||||
| 	"ERNIE-Speed-128K", | ||||
| 	"ERNIE-Lite-8K-0922", | ||||
| 	"ERNIE-Lite-8K-0308", | ||||
| 	"ERNIE-Tiny-8K", | ||||
| 	"BLOOMZ-7B", | ||||
| 	"Embedding-V1", | ||||
| 	"bge-large-zh", | ||||
| 	"bge-large-en", | ||||
| 	"tao-8k", | ||||
| } | ||||
| @@ -8,10 +8,10 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/client" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -32,9 +32,16 @@ type Message struct { | ||||
| } | ||||
| 
 | ||||
| type ChatRequest struct { | ||||
| 	Messages []Message `json:"messages"` | ||||
| 	Stream   bool      `json:"stream"` | ||||
| 	UserId   string    `json:"user_id,omitempty"` | ||||
| 	Messages        []Message `json:"messages"` | ||||
| 	Temperature     float64   `json:"temperature,omitempty"` | ||||
| 	TopP            float64   `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    float64   `json:"penalty_score,omitempty"` | ||||
| 	Stream          bool      `json:"stream,omitempty"` | ||||
| 	System          string    `json:"system,omitempty"` | ||||
| 	DisableSearch   bool      `json:"disable_search,omitempty"` | ||||
| 	EnableCitation  bool      `json:"enable_citation,omitempty"` | ||||
| 	MaxOutputTokens int       `json:"max_output_tokens,omitempty"` | ||||
| 	UserId          string    `json:"user_id,omitempty"` | ||||
| } | ||||
| 
 | ||||
| type Error struct { | ||||
| @@ -45,28 +52,28 @@ type Error struct { | ||||
| var baiduTokenStore sync.Map | ||||
| 
 | ||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	baiduRequest := ChatRequest{ | ||||
| 		Messages:        make([]Message, 0, len(request.Messages)), | ||||
| 		Temperature:     request.Temperature, | ||||
| 		TopP:            request.TopP, | ||||
| 		PenaltyScore:    request.FrequencyPenalty, | ||||
| 		Stream:          request.Stream, | ||||
| 		DisableSearch:   false, | ||||
| 		EnableCitation:  false, | ||||
| 		MaxOutputTokens: request.MaxTokens, | ||||
| 		UserId:          request.User, | ||||
| 	} | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 			baiduRequest.System = message.StringContent() | ||||
| 		} else { | ||||
| 			messages = append(messages, Message{ | ||||
| 			baiduRequest.Messages = append(baiduRequest.Messages, Message{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &ChatRequest{ | ||||
| 		Messages: messages, | ||||
| 		Stream:   request.Stream, | ||||
| 	} | ||||
| 	return &baiduRequest | ||||
| } | ||||
| 
 | ||||
| func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| @@ -298,7 +305,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) { | ||||
| 	} | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 	req.Header.Add("Accept", "application/json") | ||||
| 	res, err := util.ImpatientHTTPClient.Do(req) | ||||
| 	res, err := client.ImpatientHTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -1,15 +1,16 @@ | ||||
| package channel | ||||
| package adaptor | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"github.com/songquanpeng/one-api/relay/client" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
| 
 | ||||
| func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) { | ||||
| func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 	if meta.IsStream && c.Request.Header.Get("Accept") == "" { | ||||
| @@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||
| func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	fullRequestURL, err := a.GetRequestURL(meta) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("get request url failed: %w", err) | ||||
| @@ -38,7 +39,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod | ||||
| } | ||||
| 
 | ||||
| func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { | ||||
| 	resp, err := util.HTTPClient.Do(req) | ||||
| 	resp, err := client.HTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	channelhelper "github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	channelhelper "github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
| @@ -16,12 +17,12 @@ import ( | ||||
| type Adaptor struct { | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	version := helper.AssignOrDefault(meta.APIVersion, "v1") | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	version := helper.AssignOrDefault(meta.APIVersion, config.GeminiVersion) | ||||
| 	action := "generateContent" | ||||
| 	if meta.IsStream { | ||||
| 		action = "streamGenerateContent" | ||||
| @@ -29,7 +30,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	channelhelper.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("x-goog-api-key", meta.APIKey) | ||||
| 	return nil | ||||
| @@ -42,11 +43,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	return ConvertRequest(*request), nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return channelhelper.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		var responseText string | ||||
| 		err, responseText = StreamHandler(c, resp) | ||||
							
								
								
									
										8
									
								
								relay/adaptor/gemini/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								relay/adaptor/gemini/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| package gemini | ||||
|  | ||||
| // https://ai.google.dev/models/gemini | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", | ||||
| 	"gemini-pro-vision", "gemini-1.0-pro-vision-001", | ||||
| } | ||||
| @@ -9,7 +9,8 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/image" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| @@ -155,7 +156,7 @@ type ChatPromptFeedback struct { | ||||
| 
 | ||||
| func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), | ||||
| @@ -233,7 +234,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 			var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 			choice.Delta.Content = dummy.Content | ||||
| 			response := openai.ChatCompletionsStreamResponse{ | ||||
| 				Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 				Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 				Object:  "chat.completion.chunk", | ||||
| 				Created: helper.GetTimestamp(), | ||||
| 				Model:   "gemini-pro", | ||||
							
								
								
									
										10
									
								
								relay/adaptor/groq/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								relay/adaptor/groq/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| package groq | ||||
|  | ||||
| // https://console.groq.com/docs/models | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemma-7b-it", | ||||
| 	"llama2-7b-2048", | ||||
| 	"llama2-70b-4096", | ||||
| 	"mixtral-8x7b-32768", | ||||
| } | ||||
							
								
								
									
										21
									
								
								relay/adaptor/interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								relay/adaptor/interface.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| package adaptor | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type Adaptor interface { | ||||
| 	Init(meta *meta.Meta) | ||||
| 	GetRequestURL(meta *meta.Meta) (string, error) | ||||
| 	SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error | ||||
| 	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) | ||||
| 	ConvertImageRequest(request *model.ImageRequest) (any, error) | ||||
| 	DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) | ||||
| 	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) | ||||
| 	GetModelList() []string | ||||
| 	GetChannelName() string | ||||
| } | ||||
							
								
								
									
										9
									
								
								relay/adaptor/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								relay/adaptor/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| package lingyiwanwu | ||||
|  | ||||
| // https://platform.lingyiwanwu.com/docs | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"yi-34b-chat-0205", | ||||
| 	"yi-34b-chat-200k", | ||||
| 	"yi-vl-plus", | ||||
| } | ||||
							
								
								
									
										7
									
								
								relay/adaptor/minimax/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/adaptor/minimax/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package minimax | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"abab5.5s-chat", | ||||
| 	"abab5.5-chat", | ||||
| 	"abab6-chat", | ||||
| } | ||||
							
								
								
									
										14
									
								
								relay/adaptor/minimax/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								relay/adaptor/minimax/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| package minimax | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| func GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if meta.Mode == relaymode.ChatCompletions { | ||||
| 		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil | ||||
| 	} | ||||
| 	return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) | ||||
| } | ||||
							
								
								
									
										10
									
								
								relay/adaptor/mistral/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								relay/adaptor/mistral/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| package mistral | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"open-mistral-7b", | ||||
| 	"open-mixtral-8x7b", | ||||
| 	"mistral-small-latest", | ||||
| 	"mistral-medium-latest", | ||||
| 	"mistral-large-latest", | ||||
| 	"mistral-embed", | ||||
| } | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user