mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			94 Commits
		
	
	
		
			v0.6.0-alp
			...
			v0.6.4-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 96d7a99312 | ||
|  | 24be9de098 | ||
|  | 5b349efff9 | ||
|  | f76c46d648 | ||
|  | cdfdeea3b4 | ||
|  | 56ddbb842a | ||
|  | 99f81a267c | ||
|  | c243cd5535 | ||
|  | e96b173abe | ||
|  | 4ae311e964 | ||
|  | b14cb748d8 | ||
|  | ade19ba4a2 | ||
|  | 4d86d021c4 | ||
|  | 7a44adb5a7 | ||
|  | 9821bc7281 | ||
|  | 08831881f1 | ||
|  | 0eb2272bb7 | ||
|  | 704ec1a827 | ||
|  | 1d7470d6ad | ||
|  | 1185303346 | ||
|  | c212fcf8d7 | ||
|  | c285e000cc | ||
|  | d25ed4c009 | ||
|  | 7400885fbb | ||
|  | 11af81eb39 | ||
|  | 205aba694f | ||
|  | 8dac3afebc | ||
|  | a07791bf93 | ||
|  | 4bb662c0e4 | ||
|  | 4998d58319 | ||
|  | 190203cf8f | ||
|  | 6325c8e0b4 | ||
|  | b204f6d82b | ||
|  | 752639560f | ||
|  | 996f4d99dd | ||
|  | ebfee3b46c | ||
|  | 3e2e805d61 | ||
|  | 3edf7247c4 | ||
|  | 0926b6206b | ||
|  | 7cd57f3125 | ||
|  | 66efabd5ae | ||
|  | 8ede66a896 | ||
|  | b169173860 | ||
|  | f33555ae78 | ||
|  | c28ec10795 | ||
|  | e3767cbb07 | ||
|  | be9eb59fbb | ||
|  | 89e111ac69 | ||
|  | 2dcef85285 | ||
|  | 79d0cd378a | ||
|  | e99150bdb9 | ||
|  | a72e5fcc9e | ||
|  | 0710f8cd66 | ||
|  | 49cad7d4a5 | ||
|  | a90161cf00 | ||
|  | a45fc7d736 | ||
|  | 45940dcb12 | ||
|  | 969042b001 | ||
|  | 7e7369dbc4 | ||
|  | e54e647170 | ||
|  | 358920c858 | ||
|  | 1ea598c773 | ||
|  | 796be42487 | ||
|  | 5b50eb94e5 | ||
|  | 71c61365eb | ||
|  | b09f979b80 | ||
|  | 12440874b0 | ||
|  | 6ebc99460e | ||
|  | 27ad8bfb98 | ||
|  | 8388aa537f | ||
|  | 2346bf70af | ||
|  | f05b403ca5 | ||
|  | b33616df44 | ||
|  | cf16f44970 | ||
|  | bf2e26a48f | ||
|  | 4fb22ad4ce | ||
|  | 95cfb8e8c9 | ||
|  | c6ace985c2 | ||
|  | 10a926b8f3 | ||
|  | 2df877a352 | ||
|  | 9d8967f7d3 | ||
|  | b35f3523d3 | ||
|  | 82e916b5ff | ||
|  | de18d6fe16 | ||
|  | 1d0b7fb5ae | ||
|  | f9490bb72e | ||
|  | 76467285e8 | ||
|  | df1fd9aa81 | ||
|  | 614c2e0442 | ||
|  | eac6a0b9aa | ||
|  | b747cdbc6f | ||
|  | 6b27d6659a | ||
|  | dc5b781191 | ||
|  | c880b4a9a3 | 
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,13 @@ jobs: | |||||||
|       - name: Check out the repo |       - name: Check out the repo | ||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|  |  | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi       | ||||||
|  |  | ||||||
|       - name: Save version info |       - name: Save version info | ||||||
|         run: | |         run: | | ||||||
|           git describe --tags > VERSION  |           git describe --tags > VERSION  | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,13 @@ jobs: | |||||||
|       - name: Check out the repo |       - name: Check out the repo | ||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|  |  | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi         | ||||||
|  |  | ||||||
|       - name: Save version info |       - name: Save version info | ||||||
|         run: | |         run: | | ||||||
|           git describe --tags > VERSION  |           git describe --tags > VERSION  | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -21,6 +21,13 @@ jobs: | |||||||
|       - name: Check out the repo |       - name: Check out the repo | ||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|  |  | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|  |  | ||||||
|       - name: Save version info |       - name: Save version info | ||||||
|         run: | |         run: | | ||||||
|           git describe --tags > VERSION  |           git describe --tags > VERSION  | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,12 @@ jobs: | |||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|         with: |         with: | ||||||
|           fetch-depth: 0 |           fetch-depth: 0 | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
| @@ -38,7 +44,7 @@ jobs: | |||||||
|       - name: Build Backend (amd64) |       - name: Build Backend (amd64) | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api |           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
|       - name: Build Backend (arm64) |       - name: Build Backend (arm64) | ||||||
|         run: | |         run: | | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,12 @@ jobs: | |||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|         with: |         with: | ||||||
|           fetch-depth: 0 |           fetch-depth: 0 | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
| @@ -38,7 +44,7 @@ jobs: | |||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos |           go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -23,6 +23,12 @@ jobs: | |||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|         with: |         with: | ||||||
|           fetch-depth: 0 |           fetch-depth: 0 | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
| @@ -41,7 +47,7 @@ jobs: | |||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe |           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
|   | |||||||
| @@ -12,6 +12,10 @@ WORKDIR /web/berry | |||||||
| RUN npm install | RUN npm install | ||||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||||
|  |  | ||||||
|  | WORKDIR /web/air | ||||||
|  | RUN npm install | ||||||
|  | RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||||
|  |  | ||||||
| FROM golang AS builder2 | FROM golang AS builder2 | ||||||
|  |  | ||||||
| ENV GO111MODULE=on \ | ENV GO111MODULE=on \ | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -241,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the | |||||||
|     + Example: `SESSION_SECRET=random_string` |     + Example: `SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | ||||||
|     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
| 4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | 4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL. | ||||||
|  |     + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` | ||||||
|  | 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` |     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | 6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||||
|     + Example: `SYNC_FREQUENCY=60` |     + Example: `SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | 7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||||
|     + Example: `NODE_TYPE=slave` |     + Example: `NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | 8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` |     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | 9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` |     + Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | 10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||||
|     + Example: `POLLING_INTERVAL=5` |     + Example: `POLLING_INTERVAL=5` | ||||||
|  |  | ||||||
| ### Command Line Parameters | ### Command Line Parameters | ||||||
|   | |||||||
							
								
								
									
										13
									
								
								README.ja.md
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								README.ja.md
									
									
									
									
									
								
							| @@ -242,17 +242,18 @@ graph LR | |||||||
|     + 例: `SESSION_SECRET=random_string` |     + 例: `SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | ||||||
|     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
| 4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | 4. `LOG_SQL_DSN`: を設定すると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください。 | ||||||
|  | 5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | ||||||
|     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` |     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | 6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | ||||||
|     + 例: `SYNC_FREQUENCY=60` |     + 例: `SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | 7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | ||||||
|     + 例: `NODE_TYPE=slave` |     + 例: `NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | 8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | ||||||
|     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` |     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | 9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | ||||||
|     + 例: `CHANNEL_TEST_FREQUENCY=1440` |     + 例: `CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | 10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | ||||||
|     + 例: `POLLING_INTERVAL=5` |     + 例: `POLLING_INTERVAL=5` | ||||||
|  |  | ||||||
| ### コマンドラインパラメータ | ### コマンドラインパラメータ | ||||||
|   | |||||||
							
								
								
									
										51
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										51
									
								
								README.md
									
									
									
									
									
								
							| @@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) |    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) |    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||||
|    + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) |    + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) | ||||||
|  |    + [x] [Mistral 系列模型](https://mistral.ai/) | ||||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) |    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) |    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) |    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||||
| @@ -74,15 +75,19 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|    + [x] [360 智脑](https://ai.360.cn) |    + [x] [360 智脑](https://ai.360.cn) | ||||||
|    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) |    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||||
|    + [x] [Moonshot AI](https://platform.moonshot.cn/) |    + [x] [Moonshot AI](https://platform.moonshot.cn/) | ||||||
|  |    + [x] [百川大模型](https://platform.baichuan-ai.com) | ||||||
|    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) |    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) | ||||||
|    + [ ] [MINIMAX](https://api.minimax.chat/) (WIP) |    + [x] [MINIMAX](https://api.minimax.chat/) | ||||||
|  |    + [x] [Groq](https://wow.groq.com/) | ||||||
|  |    + [x] [Ollama](https://github.com/ollama/ollama) | ||||||
|  |    + [x] [零一万物](https://platform.lingyiwanwu.com/) | ||||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| 5. 支持**多机部署**,[详见此处](#多机部署)。 | 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||||
| 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||||
| 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||||
| 8. 支持**通道管理**,批量创建通道。 | 8. 支持**渠道管理**,批量创建渠道。 | ||||||
| 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||||
| 10. 支持渠道**设置模型列表**。 | 10. 支持渠道**设置模型列表**。 | ||||||
| 11. 支持**查看额度明细**。 | 11. 支持**查看额度明细**。 | ||||||
| @@ -103,6 +108,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||||
|  | 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||||
|  |  | ||||||
| ## 部署 | ## 部署 | ||||||
| ### 基于 Docker 进行部署 | ### 基于 Docker 进行部署 | ||||||
| @@ -343,35 +349,40 @@ graph LR | |||||||
|      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 |      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 | ||||||
|        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 |        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 | ||||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 |      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | 4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 | ||||||
|  | 5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` |    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | 6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|    + 例子:`MEMORY_CACHE_ENABLED=true` |    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||||
| 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | 7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||||
|    + 例子:`SYNC_FREQUENCY=60` |    + 例子:`SYNC_FREQUENCY=60` | ||||||
| 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | 8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||||
|    + 例子:`NODE_TYPE=slave` |    + 例子:`NODE_TYPE=slave` | ||||||
| 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` |    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` |    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | 11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||||
|     + 例子:`POLLING_INTERVAL=5` |     + 例子:`POLLING_INTERVAL=5` | ||||||
| 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | 12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` |     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 |     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||||
| 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | 13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` |     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||||
| 13. 请求频率限制: | 14. 请求频率限制: | ||||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 |     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 |     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||||
| 14. 编码器缓存设置: | 15. 编码器缓存设置: | ||||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 |     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 |     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | 16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | 17. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||||
| 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | 18. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||||
| 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | 19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||||
|  | 20. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||||
|  | 21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||||
|  | 22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||||
|  | 23. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
| @@ -410,7 +421,7 @@ https://openai.justsong.cn | |||||||
|    + 检查你的接口地址和 API Key 有没有填对。 |    + 检查你的接口地址和 API Key 有没有填对。 | ||||||
|    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 |    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||||
|    + 上游通道 429 了。 |    + 上游渠道 429 了。 | ||||||
| 7. 升级之后我的数据会丢失吗? | 7. 升级之后我的数据会丢失吗? | ||||||
|    + 如果使用 MySQL,不会。 |    + 如果使用 MySQL,不会。 | ||||||
|    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 |    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||||
| @@ -418,8 +429,8 @@ https://openai.justsong.cn | |||||||
|    + 一般情况下不需要,系统将在初始化的时候自动调整。 |    + 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||||
|    + 如果需要的话,我会在更新日志中说明,并给出脚本。 |    + 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||||
| 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? | 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? | ||||||
|    + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 |    + 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。 | ||||||
|    + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 |    + 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。 | ||||||
|  |  | ||||||
| ## 相关项目 | ## 相关项目 | ||||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package blacklist | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var blackList sync.Map | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	blackList = sync.Map{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func userId2Key(id int) string { | ||||||
|  | 	return fmt.Sprintf("userid_%d", id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BanUser(id int) { | ||||||
|  | 	blackList.Store(userId2Key(id), true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UnbanUser(id int) { | ||||||
|  | 	blackList.Delete(userId2Key(id)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func IsUserBanned(id int) bool { | ||||||
|  | 	_, ok := blackList.Load(userId2Key(id)) | ||||||
|  | 	return ok | ||||||
|  | } | ||||||
| @@ -1,7 +1,7 @@ | |||||||
| package config | package config | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/env" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"sync" | 	"sync" | ||||||
| @@ -52,6 +52,7 @@ var EmailDomainWhitelist = []string{ | |||||||
| } | } | ||||||
|  |  | ||||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||||
|  | var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true" | ||||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||||
|  |  | ||||||
| var LogConsumeEnabled = true | var LogConsumeEnabled = true | ||||||
| @@ -69,17 +70,20 @@ var WeChatServerAddress = "" | |||||||
| var WeChatServerToken = "" | var WeChatServerToken = "" | ||||||
| var WeChatAccountQRCodeImageURL = "" | var WeChatAccountQRCodeImageURL = "" | ||||||
|  |  | ||||||
|  | var MessagePusherAddress = "" | ||||||
|  | var MessagePusherToken = "" | ||||||
|  |  | ||||||
| var TurnstileSiteKey = "" | var TurnstileSiteKey = "" | ||||||
| var TurnstileSecretKey = "" | var TurnstileSecretKey = "" | ||||||
|  |  | ||||||
| var QuotaForNewUser = 0 | var QuotaForNewUser int64 = 0 | ||||||
| var QuotaForInviter = 0 | var QuotaForInviter int64 = 0 | ||||||
| var QuotaForInvitee = 0 | var QuotaForInvitee int64 = 0 | ||||||
| var ChannelDisableThreshold = 5.0 | var ChannelDisableThreshold = 5.0 | ||||||
| var AutomaticDisableChannelEnabled = false | var AutomaticDisableChannelEnabled = false | ||||||
| var AutomaticEnableChannelEnabled = false | var AutomaticEnableChannelEnabled = false | ||||||
| var QuotaRemindThreshold = 1000 | var QuotaRemindThreshold int64 = 1000 | ||||||
| var PreConsumedQuota = 500 | var PreConsumedQuota int64 = 500 | ||||||
| var ApproximateTokenEnabled = false | var ApproximateTokenEnabled = false | ||||||
| var RetryTimes = 0 | var RetryTimes = 0 | ||||||
|  |  | ||||||
| @@ -90,28 +94,29 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | |||||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | var RequestInterval = time.Duration(requestInterval) * time.Second | ||||||
|  |  | ||||||
| var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second | var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second | ||||||
|  |  | ||||||
| var BatchUpdateEnabled = false | var BatchUpdateEnabled = false | ||||||
| var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) | var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) | ||||||
|  |  | ||||||
| var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second | var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second | ||||||
|  |  | ||||||
| var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | ||||||
|  |  | ||||||
| var Theme = helper.GetOrDefaultEnvString("THEME", "default") | var Theme = env.String("THEME", "default") | ||||||
| var ValidThemes = map[string]bool{ | var ValidThemes = map[string]bool{ | ||||||
| 	"default": true, | 	"default": true, | ||||||
| 	"berry":   true, | 	"berry":   true, | ||||||
|  | 	"air":     true, | ||||||
| } | } | ||||||
|  |  | ||||||
| // All duration's unit is seconds | // All duration's unit is seconds | ||||||
| // Shouldn't larger then RateLimitKeyExpirationDuration | // Shouldn't larger then RateLimitKeyExpirationDuration | ||||||
| var ( | var ( | ||||||
| 	GlobalApiRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) | 	GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 180) | ||||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
| 	GlobalWebRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) | 	GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) | ||||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
| 	UploadRateLimitNum            = 10 | 	UploadRateLimitNum            = 10 | ||||||
| @@ -125,3 +130,11 @@ var ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute | var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||||
|  |  | ||||||
|  | var EnableMetric = env.Bool("ENABLE_METRIC", false) | ||||||
|  | var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) | ||||||
|  | var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) | ||||||
|  | var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) | ||||||
|  | var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) | ||||||
|  |  | ||||||
|  | var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ const ( | |||||||
| const ( | const ( | ||||||
| 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||||
| 	UserStatusDisabled = 2 // also don't use 0 | 	UserStatusDisabled = 2 // also don't use 0 | ||||||
|  | 	UserStatusDeleted  = 3 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -38,32 +39,40 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ChannelTypeUnknown        = 0 | 	ChannelTypeUnknown = iota | ||||||
| 	ChannelTypeOpenAI         = 1 | 	ChannelTypeOpenAI | ||||||
| 	ChannelTypeAPI2D          = 2 | 	ChannelTypeAPI2D | ||||||
| 	ChannelTypeAzure          = 3 | 	ChannelTypeAzure | ||||||
| 	ChannelTypeCloseAI        = 4 | 	ChannelTypeCloseAI | ||||||
| 	ChannelTypeOpenAISB       = 5 | 	ChannelTypeOpenAISB | ||||||
| 	ChannelTypeOpenAIMax      = 6 | 	ChannelTypeOpenAIMax | ||||||
| 	ChannelTypeOhMyGPT        = 7 | 	ChannelTypeOhMyGPT | ||||||
| 	ChannelTypeCustom         = 8 | 	ChannelTypeCustom | ||||||
| 	ChannelTypeAILS           = 9 | 	ChannelTypeAILS | ||||||
| 	ChannelTypeAIProxy        = 10 | 	ChannelTypeAIProxy | ||||||
| 	ChannelTypePaLM           = 11 | 	ChannelTypePaLM | ||||||
| 	ChannelTypeAPI2GPT        = 12 | 	ChannelTypeAPI2GPT | ||||||
| 	ChannelTypeAIGC2D         = 13 | 	ChannelTypeAIGC2D | ||||||
| 	ChannelTypeAnthropic      = 14 | 	ChannelTypeAnthropic | ||||||
| 	ChannelTypeBaidu          = 15 | 	ChannelTypeBaidu | ||||||
| 	ChannelTypeZhipu          = 16 | 	ChannelTypeZhipu | ||||||
| 	ChannelTypeAli            = 17 | 	ChannelTypeAli | ||||||
| 	ChannelTypeXunfei         = 18 | 	ChannelTypeXunfei | ||||||
| 	ChannelType360            = 19 | 	ChannelType360 | ||||||
| 	ChannelTypeOpenRouter     = 20 | 	ChannelTypeOpenRouter | ||||||
| 	ChannelTypeAIProxyLibrary = 21 | 	ChannelTypeAIProxyLibrary | ||||||
| 	ChannelTypeFastGPT        = 22 | 	ChannelTypeFastGPT | ||||||
| 	ChannelTypeTencent        = 23 | 	ChannelTypeTencent | ||||||
| 	ChannelTypeGemini         = 24 | 	ChannelTypeGemini | ||||||
| 	ChannelTypeMoonshot       = 25 | 	ChannelTypeMoonshot | ||||||
|  | 	ChannelTypeBaichuan | ||||||
|  | 	ChannelTypeMinimax | ||||||
|  | 	ChannelTypeMistral | ||||||
|  | 	ChannelTypeGroq | ||||||
|  | 	ChannelTypeOllama | ||||||
|  | 	ChannelTypeLingYiWanWu | ||||||
|  |  | ||||||
|  | 	ChannelTypeDummy | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| @@ -93,6 +102,12 @@ var ChannelBaseURLs = []string{ | |||||||
| 	"https://hunyuan.cloud.tencent.com",         // 23 | 	"https://hunyuan.cloud.tencent.com",         // 23 | ||||||
| 	"https://generativelanguage.googleapis.com", // 24 | 	"https://generativelanguage.googleapis.com", // 24 | ||||||
| 	"https://api.moonshot.cn",                   // 25 | 	"https://api.moonshot.cn",                   // 25 | ||||||
|  | 	"https://api.baichuan-ai.com",               // 26 | ||||||
|  | 	"https://api.minimax.chat",                  // 27 | ||||||
|  | 	"https://api.mistral.ai",                    // 28 | ||||||
|  | 	"https://api.groq.com/openai",               // 29 | ||||||
|  | 	"http://localhost:11434",                    // 30 | ||||||
|  | 	"https://api.lingyiwanwu.com",               // 31 | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|   | |||||||
| @@ -1,9 +1,12 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import "github.com/songquanpeng/one-api/common/helper" | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common/env" | ||||||
|  | ) | ||||||
|  |  | ||||||
| var UsingSQLite = false | var UsingSQLite = false | ||||||
| var UsingPostgreSQL = false | var UsingPostgreSQL = false | ||||||
|  | var UsingMySQL = false | ||||||
|  |  | ||||||
| var SQLitePath = "one-api.db" | var SQLitePath = "one-api.db" | ||||||
| var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) | var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) | ||||||
|   | |||||||
							
								
								
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | |||||||
|  | package env | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"os" | ||||||
|  | 	"strconv" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Bool(env string, defaultValue bool) bool { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return os.Getenv(env) == "true" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Int(env string, defaultValue int) int { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	num, err := strconv.Atoi(os.Getenv(env)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Float64(env string, defaultValue float64) float64 { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	num, err := strconv.ParseFloat(os.Getenv(env), 64) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func String(env string, defaultValue string) string { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return os.Getenv(env) | ||||||
|  | } | ||||||
| @@ -8,12 +8,24 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | const KeyRequestBody = "key_request_body" | ||||||
|  |  | ||||||
|  | func GetRequestBody(c *gin.Context) ([]byte, error) { | ||||||
|  | 	requestBody, _ := c.Get(KeyRequestBody) | ||||||
|  | 	if requestBody != nil { | ||||||
|  | 		return requestBody.([]byte), nil | ||||||
|  | 	} | ||||||
| 	requestBody, err := io.ReadAll(c.Request.Body) | 	requestBody, err := io.ReadAll(c.Request.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	err = c.Request.Body.Close() | 	_ = c.Request.Body.Close() | ||||||
|  | 	c.Set(KeyRequestBody, requestBody) | ||||||
|  | 	return requestBody.([]byte), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||||
|  | 	requestBody, err := GetRequestBody(c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,12 +3,10 @@ package helper | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"html/template" | 	"html/template" | ||||||
| 	"log" | 	"log" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net" | 	"net" | ||||||
| 	"os" |  | ||||||
| 	"os/exec" | 	"os/exec" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| @@ -187,6 +185,10 @@ func GetTimeString() string { | |||||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GenRequestID() string { | ||||||
|  | 	return GetTimeString() + GetRandomNumberString(8) | ||||||
|  | } | ||||||
|  |  | ||||||
| func Max(a int, b int) int { | func Max(a int, b int) int { | ||||||
| 	if a >= b { | 	if a >= b { | ||||||
| 		return a | 		return a | ||||||
| @@ -195,25 +197,6 @@ func Max(a int, b int) int { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetOrDefaultEnvInt(env string, defaultValue int) int { |  | ||||||
| 	if env == "" || os.Getenv(env) == "" { |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	num, err := strconv.Atoi(os.Getenv(env)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	return num |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetOrDefaultEnvString(env string, defaultValue string) string { |  | ||||||
| 	if env == "" || os.Getenv(env) == "" { |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	return os.Getenv(env) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func AssignOrDefault(value string, defaultValue string) string { | func AssignOrDefault(value string, defaultValue string) string { | ||||||
| 	if len(value) != 0 { | 	if len(value) != 0 { | ||||||
| 		return value | 		return value | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"io" | 	"io" | ||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| @@ -13,14 +15,12 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
|  | 	loggerDEBUG = "DEBUG" | ||||||
| 	loggerINFO  = "INFO" | 	loggerINFO  = "INFO" | ||||||
| 	loggerWarn  = "WARN" | 	loggerWarn  = "WARN" | ||||||
| 	loggerError = "ERR" | 	loggerError = "ERR" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const maxLogCount = 1000000 |  | ||||||
|  |  | ||||||
| var logCount int |  | ||||||
| var setupLogLock sync.Mutex | var setupLogLock sync.Mutex | ||||||
| var setupLogWorking bool | var setupLogWorking bool | ||||||
|  |  | ||||||
| @@ -55,6 +55,12 @@ func SysError(s string) { | |||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Debug(ctx context.Context, msg string) { | ||||||
|  | 	if config.DebugEnabled { | ||||||
|  | 		logHelper(ctx, loggerDEBUG, msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func Info(ctx context.Context, msg string) { | func Info(ctx context.Context, msg string) { | ||||||
| 	logHelper(ctx, loggerINFO, msg) | 	logHelper(ctx, loggerINFO, msg) | ||||||
| } | } | ||||||
| @@ -67,6 +73,10 @@ func Error(ctx context.Context, msg string) { | |||||||
| 	logHelper(ctx, loggerError, msg) | 	logHelper(ctx, loggerError, msg) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Debugf(ctx context.Context, format string, a ...any) { | ||||||
|  | 	Debug(ctx, fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
| func Infof(ctx context.Context, format string, a ...any) { | func Infof(ctx context.Context, format string, a ...any) { | ||||||
| 	Info(ctx, fmt.Sprintf(format, a...)) | 	Info(ctx, fmt.Sprintf(format, a...)) | ||||||
| } | } | ||||||
| @@ -85,11 +95,12 @@ func logHelper(ctx context.Context, level string, msg string) { | |||||||
| 		writer = gin.DefaultWriter | 		writer = gin.DefaultWriter | ||||||
| 	} | 	} | ||||||
| 	id := ctx.Value(RequestIdKey) | 	id := ctx.Value(RequestIdKey) | ||||||
|  | 	if id == nil { | ||||||
|  | 		id = helper.GenRequestID() | ||||||
|  | 	} | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||||
| 	logCount++ // we don't need accurate count, so no lock here | 	if !setupLogWorking { | ||||||
| 	if logCount > maxLogCount && !setupLogWorking { |  | ||||||
| 		logCount = 0 |  | ||||||
| 		setupLogWorking = true | 		setupLogWorking = true | ||||||
| 		go func() { | 		go func() { | ||||||
| 			SetupLogger() | 			SetupLogger() | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| package common | package message | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| @@ -12,6 +12,9 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func SendEmail(subject string, receiver string, content string) error { | func SendEmail(subject string, receiver string, content string) error { | ||||||
|  | 	if receiver == "" { | ||||||
|  | 		return fmt.Errorf("receiver is empty") | ||||||
|  | 	} | ||||||
| 	if config.SMTPFrom == "" { // for compatibility | 	if config.SMTPFrom == "" { // for compatibility | ||||||
| 		config.SMTPFrom = config.SMTPAccount | 		config.SMTPFrom = config.SMTPAccount | ||||||
| 	} | 	} | ||||||
							
								
								
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | package message | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	ByAll           = "all" | ||||||
|  | 	ByEmail         = "email" | ||||||
|  | 	ByMessagePusher = "message_pusher" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Notify(by string, title string, description string, content string) error { | ||||||
|  | 	if by == ByEmail { | ||||||
|  | 		return SendEmail(title, config.RootUserEmail, content) | ||||||
|  | 	} | ||||||
|  | 	if by == ByMessagePusher { | ||||||
|  | 		return SendMessage(title, description, content) | ||||||
|  | 	} | ||||||
|  | 	return fmt.Errorf("unknown notify method: %s", by) | ||||||
|  | } | ||||||
							
								
								
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | package message | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type request struct { | ||||||
|  | 	Title       string `json:"title"` | ||||||
|  | 	Description string `json:"description"` | ||||||
|  | 	Content     string `json:"content"` | ||||||
|  | 	URL         string `json:"url"` | ||||||
|  | 	Channel     string `json:"channel"` | ||||||
|  | 	Token       string `json:"token"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type response struct { | ||||||
|  | 	Success bool   `json:"success"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SendMessage(title string, description string, content string) error { | ||||||
|  | 	if config.MessagePusherAddress == "" { | ||||||
|  | 		return errors.New("message pusher address is not set") | ||||||
|  | 	} | ||||||
|  | 	req := request{ | ||||||
|  | 		Title:       title, | ||||||
|  | 		Description: description, | ||||||
|  | 		Content:     content, | ||||||
|  | 		Token:       config.MessagePusherToken, | ||||||
|  | 	} | ||||||
|  | 	data, err := json.Marshal(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	resp, err := http.Post(config.MessagePusherAddress, | ||||||
|  | 		"application/json", bytes.NewBuffer(data)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	var res response | ||||||
|  | 	err = json.NewDecoder(resp.Body).Decode(&res) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if !res.Success { | ||||||
|  | 		return errors.New(res.Message) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
| @@ -4,32 +4,8 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var DalleSizeRatios = map[string]map[string]float64{ |  | ||||||
| 	"dall-e-2": { |  | ||||||
| 		"256x256":   1, |  | ||||||
| 		"512x512":   1.125, |  | ||||||
| 		"1024x1024": 1.25, |  | ||||||
| 	}, |  | ||||||
| 	"dall-e-3": { |  | ||||||
| 		"1024x1024": 1, |  | ||||||
| 		"1024x1792": 2, |  | ||||||
| 		"1792x1024": 2, |  | ||||||
| 	}, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DalleGenerationImageAmounts = map[string][2]int{ |  | ||||||
| 	"dall-e-2": {1, 10}, |  | ||||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DalleImagePromptLengthLimitations = map[string]int{ |  | ||||||
| 	"dall-e-2": 1000, |  | ||||||
| 	"dall-e-3": 4000, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	USD2RMB = 7 | 	USD2RMB = 7 | ||||||
| 	USD     = 500 // $0.002 = 1 -> $1 = 500 | 	USD     = 500 // $0.002 = 1 -> $1 = 500 | ||||||
| @@ -40,7 +16,6 @@ const ( | |||||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||||
| // https://openai.com/pricing | // https://openai.com/pricing | ||||||
| // TODO: when a new api is enabled, check the pricing here |  | ||||||
| // 1 === $0.002 / 1K tokens | // 1 === $0.002 / 1K tokens | ||||||
| // 1 === ¥0.014 / 1k tokens | // 1 === ¥0.014 / 1k tokens | ||||||
| var ModelRatio = map[string]float64{ | var ModelRatio = map[string]float64{ | ||||||
| @@ -55,7 +30,7 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"gpt-4-0125-preview":      5,    // $0.01 / 1K tokens | 	"gpt-4-0125-preview":      5,    // $0.01 / 1K tokens | ||||||
| 	"gpt-4-turbo-preview":     5,    // $0.01 / 1K tokens | 	"gpt-4-turbo-preview":     5,    // $0.01 / 1K tokens | ||||||
| 	"gpt-4-vision-preview":    5,    // $0.01 / 1K tokens | 	"gpt-4-vision-preview":    5,    // $0.01 / 1K tokens | ||||||
| 	"gpt-3.5-turbo":           0.75, // $0.0015 / 1K tokens | 	"gpt-3.5-turbo":           0.25, // $0.0005 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-0301":      0.75, | 	"gpt-3.5-turbo-0301":      0.75, | ||||||
| 	"gpt-3.5-turbo-0613":      0.75, | 	"gpt-3.5-turbo-0613":      0.75, | ||||||
| 	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens | 	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens | ||||||
| @@ -87,21 +62,35 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"text-search-ada-doc-001": 10, | 	"text-search-ada-doc-001": 10, | ||||||
| 	"text-moderation-stable":  0.1, | 	"text-moderation-stable":  0.1, | ||||||
| 	"text-moderation-latest":  0.1, | 	"text-moderation-latest":  0.1, | ||||||
| 	"dall-e-2":                8,     // $0.016 - $0.020 / image | 	"dall-e-2":                8,  // $0.016 - $0.020 / image | ||||||
| 	"dall-e-3":                20,    // $0.040 - $0.120 / image | 	"dall-e-3":                20, // $0.040 - $0.120 / image | ||||||
| 	"claude-instant-1":        0.815, // $1.63 / 1M tokens | 	// https://www.anthropic.com/api#pricing | ||||||
| 	"claude-2":                5.51,  // $11.02 / 1M tokens | 	"claude-instant-1.2":       0.8 / 1000 * USD, | ||||||
| 	"claude-2.0":              5.51,  // $11.02 / 1M tokens | 	"claude-2.0":               8.0 / 1000 * USD, | ||||||
| 	"claude-2.1":              5.51,  // $11.02 / 1M tokens | 	"claude-2.1":               8.0 / 1000 * USD, | ||||||
|  | 	"claude-3-haiku-20240307":  0.25 / 1000 * USD, | ||||||
|  | 	"claude-3-sonnet-20240229": 3.0 / 1000 * USD, | ||||||
|  | 	"claude-3-opus-20240229":   15.0 / 1000 * USD, | ||||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | ||||||
| 	"ERNIE-Bot":                 0.8572,     // ¥0.012 / 1k tokens | 	"ERNIE-Bot":       0.8572,     // ¥0.012 / 1k tokens | ||||||
| 	"ERNIE-Bot-turbo":           0.5715,     // ¥0.008 / 1k tokens | 	"ERNIE-Bot-turbo": 0.5715,     // ¥0.008 / 1k tokens | ||||||
| 	"ERNIE-Bot-4":               0.12 * RMB, // ¥0.12 / 1k tokens | 	"ERNIE-Bot-4":     0.12 * RMB, // ¥0.12 / 1k tokens | ||||||
| 	"ERNIE-Bot-8k":              0.024 * RMB, | 	"ERNIE-Bot-8k":    0.024 * RMB, | ||||||
| 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | 	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens | ||||||
|  | 	"bge-large-zh":    0.002 * RMB, | ||||||
|  | 	"bge-large-en":    0.002 * RMB, | ||||||
|  | 	"bge-large-8k":    0.002 * RMB, | ||||||
|  | 	// https://ai.google.dev/pricing | ||||||
| 	"PaLM-2":                    1, | 	"PaLM-2":                    1, | ||||||
| 	"gemini-pro":                1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens | 	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||||
| 	"gemini-pro-vision":         1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens | 	"gemini-pro-vision":         1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||||
|  | 	"gemini-1.0-pro-vision-001": 1, | ||||||
|  | 	"gemini-1.0-pro-001":        1, | ||||||
|  | 	"gemini-1.5-pro":            1, | ||||||
|  | 	// https://open.bigmodel.cn/pricing | ||||||
|  | 	"glm-4":                     0.1 * RMB, | ||||||
|  | 	"glm-4v":                    0.1 * RMB, | ||||||
|  | 	"glm-3-turbo":               0.005 * RMB, | ||||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||||
| @@ -127,6 +116,66 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"moonshot-v1-8k":   0.012 * RMB, | 	"moonshot-v1-8k":   0.012 * RMB, | ||||||
| 	"moonshot-v1-32k":  0.024 * RMB, | 	"moonshot-v1-32k":  0.024 * RMB, | ||||||
| 	"moonshot-v1-128k": 0.06 * RMB, | 	"moonshot-v1-128k": 0.06 * RMB, | ||||||
|  | 	// https://platform.baichuan-ai.com/price | ||||||
|  | 	"Baichuan2-Turbo":      0.008 * RMB, | ||||||
|  | 	"Baichuan2-Turbo-192k": 0.016 * RMB, | ||||||
|  | 	"Baichuan2-53B":        0.02 * RMB, | ||||||
|  | 	// https://api.minimax.chat/document/price | ||||||
|  | 	"abab6-chat":    0.1 * RMB, | ||||||
|  | 	"abab5.5-chat":  0.015 * RMB, | ||||||
|  | 	"abab5.5s-chat": 0.005 * RMB, | ||||||
|  | 	// https://docs.mistral.ai/platform/pricing/ | ||||||
|  | 	"open-mistral-7b":       0.25 / 1000 * USD, | ||||||
|  | 	"open-mixtral-8x7b":     0.7 / 1000 * USD, | ||||||
|  | 	"mistral-small-latest":  2.0 / 1000 * USD, | ||||||
|  | 	"mistral-medium-latest": 2.7 / 1000 * USD, | ||||||
|  | 	"mistral-large-latest":  8.0 / 1000 * USD, | ||||||
|  | 	"mistral-embed":         0.1 / 1000 * USD, | ||||||
|  | 	// https://wow.groq.com/ | ||||||
|  | 	"llama2-70b-4096":    0.7 / 1000 * USD, | ||||||
|  | 	"llama2-7b-2048":     0.1 / 1000 * USD, | ||||||
|  | 	"mixtral-8x7b-32768": 0.27 / 1000 * USD, | ||||||
|  | 	"gemma-7b-it":        0.1 / 1000 * USD, | ||||||
|  | 	// https://platform.lingyiwanwu.com/docs#-计费单元 | ||||||
|  | 	"yi-34b-chat-0205": 2.5 / 1000 * RMB, | ||||||
|  | 	"yi-34b-chat-200k": 12.0 / 1000 * RMB, | ||||||
|  | 	"yi-vl-plus":       6.0 / 1000 * RMB, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var CompletionRatio = map[string]float64{} | ||||||
|  |  | ||||||
|  | var DefaultModelRatio map[string]float64 | ||||||
|  | var DefaultCompletionRatio map[string]float64 | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	DefaultModelRatio = make(map[string]float64) | ||||||
|  | 	for k, v := range ModelRatio { | ||||||
|  | 		DefaultModelRatio[k] = v | ||||||
|  | 	} | ||||||
|  | 	DefaultCompletionRatio = make(map[string]float64) | ||||||
|  | 	for k, v := range CompletionRatio { | ||||||
|  | 		DefaultCompletionRatio[k] = v | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func AddNewMissingRatio(oldRatio string) string { | ||||||
|  | 	newRatio := make(map[string]float64) | ||||||
|  | 	err := json.Unmarshal([]byte(oldRatio), &newRatio) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError("error unmarshalling old ratio: " + err.Error()) | ||||||
|  | 		return oldRatio | ||||||
|  | 	} | ||||||
|  | 	for k, v := range DefaultModelRatio { | ||||||
|  | 		if _, ok := newRatio[k]; !ok { | ||||||
|  | 			newRatio[k] = v | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	jsonBytes, err := json.Marshal(newRatio) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError("error marshalling new ratio: " + err.Error()) | ||||||
|  | 		return oldRatio | ||||||
|  | 	} | ||||||
|  | 	return string(jsonBytes) | ||||||
| } | } | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { | func ModelRatio2JSONString() string { | ||||||
| @@ -147,6 +196,9 @@ func GetModelRatio(name string) float64 { | |||||||
| 		name = strings.TrimSuffix(name, "-internet") | 		name = strings.TrimSuffix(name, "-internet") | ||||||
| 	} | 	} | ||||||
| 	ratio, ok := ModelRatio[name] | 	ratio, ok := ModelRatio[name] | ||||||
|  | 	if !ok { | ||||||
|  | 		ratio, ok = DefaultModelRatio[name] | ||||||
|  | 	} | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		logger.SysError("model ratio not found: " + name) | 		logger.SysError("model ratio not found: " + name) | ||||||
| 		return 30 | 		return 30 | ||||||
| @@ -154,8 +206,6 @@ func GetModelRatio(name string) float64 { | |||||||
| 	return ratio | 	return ratio | ||||||
| } | } | ||||||
|  |  | ||||||
| var CompletionRatio = map[string]float64{} |  | ||||||
|  |  | ||||||
| func CompletionRatio2JSONString() string { | func CompletionRatio2JSONString() string { | ||||||
| 	jsonBytes, err := json.Marshal(CompletionRatio) | 	jsonBytes, err := json.Marshal(CompletionRatio) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -173,8 +223,11 @@ func GetCompletionRatio(name string) float64 { | |||||||
| 	if ratio, ok := CompletionRatio[name]; ok { | 	if ratio, ok := CompletionRatio[name]; ok { | ||||||
| 		return ratio | 		return ratio | ||||||
| 	} | 	} | ||||||
|  | 	if ratio, ok := DefaultCompletionRatio[name]; ok { | ||||||
|  | 		return ratio | ||||||
|  | 	} | ||||||
| 	if strings.HasPrefix(name, "gpt-3.5") { | 	if strings.HasPrefix(name, "gpt-3.5") { | ||||||
| 		if strings.HasSuffix(name, "0125") { | 		if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { | ||||||
| 			// https://openai.com/blog/new-embedding-models-and-api-updates | 			// https://openai.com/blog/new-embedding-models-and-api-updates | ||||||
| 			// Updated GPT-3.5 Turbo model and lower pricing | 			// Updated GPT-3.5 Turbo model and lower pricing | ||||||
| 			return 3 | 			return 3 | ||||||
| @@ -182,16 +235,7 @@ func GetCompletionRatio(name string) float64 { | |||||||
| 		if strings.HasSuffix(name, "1106") { | 		if strings.HasSuffix(name, "1106") { | ||||||
| 			return 2 | 			return 2 | ||||||
| 		} | 		} | ||||||
| 		if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { | 		return 4.0 / 3.0 | ||||||
| 			// TODO: clear this after 2023-12-11 |  | ||||||
| 			now := time.Now() |  | ||||||
| 			// https://platform.openai.com/docs/models/continuous-model-upgrades |  | ||||||
| 			// if after 2023-12-11, use 2 |  | ||||||
| 			if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { |  | ||||||
| 				return 2 |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return 1.333333 |  | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "gpt-4") { | 	if strings.HasPrefix(name, "gpt-4") { | ||||||
| 		if strings.HasSuffix(name, "preview") { | 		if strings.HasSuffix(name, "preview") { | ||||||
| @@ -199,11 +243,21 @@ func GetCompletionRatio(name string) float64 { | |||||||
| 		} | 		} | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "claude-instant-1") { | 	if strings.HasPrefix(name, "claude-3") { | ||||||
| 		return 3.38 | 		return 5 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "claude-2") { | 	if strings.HasPrefix(name, "claude-") { | ||||||
| 		return 2.965517 | 		return 3 | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(name, "mistral-") { | ||||||
|  | 		return 3 | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(name, "gemini-") { | ||||||
|  | 		return 3 | ||||||
|  | 	} | ||||||
|  | 	switch name { | ||||||
|  | 	case "llama2-70b-4096": | ||||||
|  | 		return 0.8 / 0.7 | ||||||
| 	} | 	} | ||||||
| 	return 1 | 	return 1 | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								common/random.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								common/random.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | |||||||
|  | package common | ||||||
|  |  | ||||||
|  | import "math/rand" | ||||||
|  |  | ||||||
|  | // RandRange returns a random number between min and max (max is not included) | ||||||
|  | func RandRange(min, max int) int { | ||||||
|  | 	return min + rand.Intn(max-min) | ||||||
|  | } | ||||||
| @@ -5,7 +5,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func LogQuota(quota int) string { | func LogQuota(quota int64) string { | ||||||
| 	if config.DisplayInCurrencyEnabled { | 	if config.DisplayInCurrencyEnabled { | ||||||
| 		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) | 		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) | ||||||
| 	} else { | 	} else { | ||||||
|   | |||||||
| @@ -8,8 +8,8 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetSubscription(c *gin.Context) { | func GetSubscription(c *gin.Context) { | ||||||
| 	var remainQuota int | 	var remainQuota int64 | ||||||
| 	var usedQuota int | 	var usedQuota int64 | ||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
| 	var expiredTime int64 | 	var expiredTime int64 | ||||||
| @@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetUsage(c *gin.Context) { | func GetUsage(c *gin.Context) { | ||||||
| 	var quota int | 	var quota int64 | ||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
| 	if config.DisplayTokenStatEnabled { | 	if config.DisplayTokenStatEnabled { | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| @@ -295,7 +296,7 @@ func UpdateChannelBalance(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func updateAllChannelsBalance() error { | func updateAllChannelsBalance() error { | ||||||
| 	channels, err := model.GetAllChannels(0, 0, true) | 	channels, err := model.GetAllChannels(0, 0, "all") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -313,7 +314,7 @@ func updateAllChannelsBalance() error { | |||||||
| 		} else { | 		} else { | ||||||
| 			// err is nil & balance <= 0 means quota is used up | 			// err is nil & balance <= 0 means quota is used up | ||||||
| 			if balance <= 0 { | 			if balance <= 0 { | ||||||
| 				disableChannel(channel.Id, channel.Name, "余额不足") | 				monitor.DisableChannel(channel.Id, channel.Name, "余额不足") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		time.Sleep(config.RequestInterval) | 		time.Sleep(config.RequestInterval) | ||||||
| @@ -322,15 +323,14 @@ func updateAllChannelsBalance() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateAllChannelsBalance(c *gin.Context) { | func UpdateAllChannelsBalance(c *gin.Context) { | ||||||
| 	// TODO: make it async | 	//err := updateAllChannelsBalance() | ||||||
| 	err := updateAllChannelsBalance() | 	//if err != nil { | ||||||
| 	if err != nil { | 	//	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 	//		"success": false, | ||||||
| 			"success": false, | 	//		"message": err.Error(), | ||||||
| 			"message": err.Error(), | 	//	}) | ||||||
| 		}) | 	//	return | ||||||
| 		return | 	//} | ||||||
| 	} |  | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
|   | |||||||
| @@ -8,7 +8,10 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
|  | 	"github.com/songquanpeng/one-api/middleware" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/helper" | 	"github.com/songquanpeng/one-api/relay/helper" | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
| @@ -18,6 +21,7 @@ import ( | |||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -26,7 +30,7 @@ import ( | |||||||
|  |  | ||||||
| func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | ||||||
| 	testRequest := &relaymodel.GeneralOpenAIRequest{ | 	testRequest := &relaymodel.GeneralOpenAIRequest{ | ||||||
| 		MaxTokens: 1, | 		MaxTokens: 2, | ||||||
| 		Stream:    false, | 		Stream:    false, | ||||||
| 		Model:     "gpt-3.5-turbo", | 		Model:     "gpt-3.5-turbo", | ||||||
| 	} | 	} | ||||||
| @@ -51,6 +55,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | |||||||
| 	c.Request.Header.Set("Content-Type", "application/json") | 	c.Request.Header.Set("Content-Type", "application/json") | ||||||
| 	c.Set("channel", channel.Type) | 	c.Set("channel", channel.Type) | ||||||
| 	c.Set("base_url", channel.GetBaseURL()) | 	c.Set("base_url", channel.GetBaseURL()) | ||||||
|  | 	middleware.SetupContextForSelectedChannel(c, channel, "") | ||||||
| 	meta := util.GetRelayMeta(c) | 	meta := util.GetRelayMeta(c) | ||||||
| 	apiType := constant.ChannelType2APIType(channel.Type) | 	apiType := constant.ChannelType2APIType(channel.Type) | ||||||
| 	adaptor := helper.GetAdaptor(apiType) | 	adaptor := helper.GetAdaptor(apiType) | ||||||
| @@ -59,6 +64,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | |||||||
| 	} | 	} | ||||||
| 	adaptor.Init(meta) | 	adaptor.Init(meta) | ||||||
| 	modelName := adaptor.GetModelList()[0] | 	modelName := adaptor.GetModelList()[0] | ||||||
|  | 	if !strings.Contains(channel.Models, modelName) { | ||||||
|  | 		modelNames := strings.Split(channel.Models, ",") | ||||||
|  | 		if len(modelNames) > 0 { | ||||||
|  | 			modelName = modelNames[0] | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	request := buildTestRequest() | 	request := buildTestRequest() | ||||||
| 	request.Model = modelName | 	request.Model = modelName | ||||||
| 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | ||||||
| @@ -139,33 +150,7 @@ func TestChannel(c *gin.Context) { | |||||||
| var testAllChannelsLock sync.Mutex | var testAllChannelsLock sync.Mutex | ||||||
| var testAllChannelsRunning bool = false | var testAllChannelsRunning bool = false | ||||||
|  |  | ||||||
| func notifyRootUser(subject string, content string) { | func testChannels(notify bool, scope string) error { | ||||||
| 	if config.RootUserEmail == "" { |  | ||||||
| 		config.RootUserEmail = model.GetRootUserEmail() |  | ||||||
| 	} |  | ||||||
| 	err := common.SendEmail(subject, config.RootUserEmail, content) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // disable & notify |  | ||||||
| func disableChannel(channelId int, channelName string, reason string) { |  | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) |  | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) |  | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) |  | ||||||
| 	notifyRootUser(subject, content) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // enable & notify |  | ||||||
| func enableChannel(channelId int, channelName string) { |  | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) |  | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) |  | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) |  | ||||||
| 	notifyRootUser(subject, content) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func testAllChannels(notify bool) error { |  | ||||||
| 	if config.RootUserEmail == "" { | 	if config.RootUserEmail == "" { | ||||||
| 		config.RootUserEmail = model.GetRootUserEmail() | 		config.RootUserEmail = model.GetRootUserEmail() | ||||||
| 	} | 	} | ||||||
| @@ -176,7 +161,7 @@ func testAllChannels(notify bool) error { | |||||||
| 	} | 	} | ||||||
| 	testAllChannelsRunning = true | 	testAllChannelsRunning = true | ||||||
| 	testAllChannelsLock.Unlock() | 	testAllChannelsLock.Unlock() | ||||||
| 	channels, err := model.GetAllChannels(0, 0, true) | 	channels, err := model.GetAllChannels(0, 0, scope) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -193,13 +178,17 @@ func testAllChannels(notify bool) error { | |||||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | 			milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | 			if isChannelEnabled && milliseconds > disableThreshold { | ||||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				if config.AutomaticDisableChannelEnabled { | ||||||
|  | 					monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||||
|  | 				} else { | ||||||
|  | 					_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | ||||||
| 				enableChannel(channel.Id, channel.Name) | 				monitor.EnableChannel(channel.Id, channel.Name) | ||||||
| 			} | 			} | ||||||
| 			channel.UpdateResponseTime(milliseconds) | 			channel.UpdateResponseTime(milliseconds) | ||||||
| 			time.Sleep(config.RequestInterval) | 			time.Sleep(config.RequestInterval) | ||||||
| @@ -208,7 +197,7 @@ func testAllChannels(notify bool) error { | |||||||
| 		testAllChannelsRunning = false | 		testAllChannelsRunning = false | ||||||
| 		testAllChannelsLock.Unlock() | 		testAllChannelsLock.Unlock() | ||||||
| 		if notify { | 		if notify { | ||||||
| 			err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | 			err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常") | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | 				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||||
| 			} | 			} | ||||||
| @@ -217,8 +206,12 @@ func testAllChannels(notify bool) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestAllChannels(c *gin.Context) { | func TestChannels(c *gin.Context) { | ||||||
| 	err := testAllChannels(true) | 	scope := c.Query("scope") | ||||||
|  | 	if scope == "" { | ||||||
|  | 		scope = "all" | ||||||
|  | 	} | ||||||
|  | 	err := testChannels(true, scope) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -237,7 +230,7 @@ func AutomaticallyTestChannels(frequency int) { | |||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||||
| 		logger.SysLog("testing all channels") | 		logger.SysLog("testing all channels") | ||||||
| 		_ = testAllChannels(false) | 		_ = testChannels(false, "all") | ||||||
| 		logger.SysLog("channel test finished") | 		logger.SysLog("channel test finished") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) | 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -110,7 +111,7 @@ func SendEmailVerification(c *gin.Context) { | |||||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | ||||||
| 		"<p>您的验证码为: <strong>%s</strong></p>"+ | 		"<p>您的验证码为: <strong>%s</strong></p>"+ | ||||||
| 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) | 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := message.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -149,7 +150,7 @@ func SendPasswordResetEmail(c *gin.Context) { | |||||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) | 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := message.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
|   | |||||||
| @@ -3,11 +3,13 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/helper" | 	"github.com/songquanpeng/one-api/relay/helper" | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
|  | 	"net/http" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/models/list | // https://platform.openai.com/docs/api-reference/models/list | ||||||
| @@ -39,6 +41,7 @@ type OpenAIModels struct { | |||||||
|  |  | ||||||
| var openAIModels []OpenAIModels | var openAIModels []OpenAIModels | ||||||
| var openAIModelsMap map[string]OpenAIModels | var openAIModelsMap map[string]OpenAIModels | ||||||
|  | var channelId2Models map[int][]string | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	var permission []OpenAIModelPermission | 	var permission []OpenAIModelPermission | ||||||
| @@ -76,32 +79,44 @@ func init() { | |||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	for _, modelName := range ai360.ModelList { | 	for _, channelType := range openai.CompatibleChannels { | ||||||
| 		openAIModels = append(openAIModels, OpenAIModels{ | 		if channelType == common.ChannelTypeAzure { | ||||||
| 			Id:         modelName, | 			continue | ||||||
| 			Object:     "model", | 		} | ||||||
| 			Created:    1626777600, | 		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) | ||||||
| 			OwnedBy:    "360", | 		for _, modelName := range channelModelList { | ||||||
| 			Permission: permission, | 			openAIModels = append(openAIModels, OpenAIModels{ | ||||||
| 			Root:       modelName, | 				Id:         modelName, | ||||||
| 			Parent:     nil, | 				Object:     "model", | ||||||
| 		}) | 				Created:    1626777600, | ||||||
| 	} | 				OwnedBy:    channelName, | ||||||
| 	for _, modelName := range moonshot.ModelList { | 				Permission: permission, | ||||||
| 		openAIModels = append(openAIModels, OpenAIModels{ | 				Root:       modelName, | ||||||
| 			Id:         modelName, | 				Parent:     nil, | ||||||
| 			Object:     "model", | 			}) | ||||||
| 			Created:    1626777600, | 		} | ||||||
| 			OwnedBy:    "moonshot", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       modelName, |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | 	openAIModelsMap = make(map[string]OpenAIModels) | ||||||
| 	for _, model := range openAIModels { | 	for _, model := range openAIModels { | ||||||
| 		openAIModelsMap[model.Id] = model | 		openAIModelsMap[model.Id] = model | ||||||
| 	} | 	} | ||||||
|  | 	channelId2Models = make(map[int][]string) | ||||||
|  | 	for i := 1; i < common.ChannelTypeDummy; i++ { | ||||||
|  | 		adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) | ||||||
|  | 		meta := &util.RelayMeta{ | ||||||
|  | 			ChannelType: i, | ||||||
|  | 		} | ||||||
|  | 		adaptor.Init(meta) | ||||||
|  | 		channelId2Models[i] = adaptor.GetModelList() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DashboardListModels(c *gin.Context) { | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    channelId2Models, | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func ListModels(c *gin.Context) { | func ListModels(c *gin.Context) { | ||||||
|   | |||||||
| @@ -1,18 +1,22 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/middleware" | 	"github.com/songquanpeng/one-api/middleware" | ||||||
| 	dbmodel "github.com/songquanpeng/one-api/model" | 	dbmodel "github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/controller" | 	"github.com/songquanpeng/one-api/relay/controller" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
|  | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -38,11 +42,16 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { | |||||||
| func Relay(c *gin.Context) { | func Relay(c *gin.Context) { | ||||||
| 	ctx := c.Request.Context() | 	ctx := c.Request.Context() | ||||||
| 	relayMode := constant.Path2RelayMode(c.Request.URL.Path) | 	relayMode := constant.Path2RelayMode(c.Request.URL.Path) | ||||||
| 	bizErr := relay(c, relayMode) | 	if config.DebugEnabled { | ||||||
| 	if bizErr == nil { | 		requestBody, _ := common.GetRequestBody(c) | ||||||
| 		return | 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||||
| 	} | 	} | ||||||
| 	channelId := c.GetInt("channel_id") | 	channelId := c.GetInt("channel_id") | ||||||
|  | 	bizErr := relay(c, relayMode) | ||||||
|  | 	if bizErr == nil { | ||||||
|  | 		monitor.Emit(channelId, true) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
| 	lastFailedChannelId := channelId | 	lastFailedChannelId := channelId | ||||||
| 	channelName := c.GetString("channel_name") | 	channelName := c.GetString("channel_name") | ||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
| @@ -50,12 +59,12 @@ func Relay(c *gin.Context) { | |||||||
| 	go processChannelRelayError(ctx, channelId, channelName, bizErr) | 	go processChannelRelayError(ctx, channelId, channelName, bizErr) | ||||||
| 	requestId := c.GetString(logger.RequestIdKey) | 	requestId := c.GetString(logger.RequestIdKey) | ||||||
| 	retryTimes := config.RetryTimes | 	retryTimes := config.RetryTimes | ||||||
| 	if !shouldRetry(bizErr.StatusCode) { | 	if !shouldRetry(c, bizErr.StatusCode) { | ||||||
| 		logger.Errorf(ctx, "relay error happen, but status code is %d, won't retry in this case", bizErr.StatusCode) | 		logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) | ||||||
| 		retryTimes = 0 | 		retryTimes = 0 | ||||||
| 	} | 	} | ||||||
| 	for i := retryTimes; i > 0; i-- { | 	for i := retryTimes; i > 0; i-- { | ||||||
| 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) | 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) | 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) | ||||||
| 			break | 			break | ||||||
| @@ -65,6 +74,8 @@ func Relay(c *gin.Context) { | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		middleware.SetupContextForSelectedChannel(c, channel, originalModel) | 		middleware.SetupContextForSelectedChannel(c, channel, originalModel) | ||||||
|  | 		requestBody, err := common.GetRequestBody(c) | ||||||
|  | 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
| 		bizErr = relay(c, relayMode) | 		bizErr = relay(c, relayMode) | ||||||
| 		if bizErr == nil { | 		if bizErr == nil { | ||||||
| 			return | 			return | ||||||
| @@ -85,7 +96,10 @@ func Relay(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func shouldRetry(statusCode int) bool { | func shouldRetry(c *gin.Context, statusCode int) bool { | ||||||
|  | 	if _, ok := c.Get("specific_channel_id"); ok { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
| 	if statusCode == http.StatusTooManyRequests { | 	if statusCode == http.StatusTooManyRequests { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| @@ -105,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st | |||||||
| 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) | 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) | ||||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||||
| 		disableChannel(channelId, channelName, err.Message) | 		monitor.DisableChannel(channelId, channelName, err.Message) | ||||||
|  | 	} else { | ||||||
|  | 		monitor.Emit(channelId, false) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -16,7 +16,10 @@ func GetAllTokens(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) |  | ||||||
|  | 	order := c.Query("order") | ||||||
|  | 	tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -139,6 +142,7 @@ func AddToken(c *gin.Context) { | |||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
|  | 		"data":    cleanToken, | ||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|   | |||||||
| @@ -180,24 +180,27 @@ func Register(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllUsers(c *gin.Context) { | func GetAllUsers(c *gin.Context) { | ||||||
| 	p, _ := strconv.Atoi(c.Query("p")) |     p, _ := strconv.Atoi(c.Query("p")) | ||||||
| 	if p < 0 { |     if p < 0 { | ||||||
| 		p = 0 |         p = 0 | ||||||
| 	} |     } | ||||||
| 	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) |      | ||||||
| 	if err != nil { |     order := c.DefaultQuery("order", "") | ||||||
| 		c.JSON(http.StatusOK, gin.H{ |     users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||||
| 			"success": false, | 	 | ||||||
| 			"message": err.Error(), |     if err != nil { | ||||||
| 		}) |         c.JSON(http.StatusOK, gin.H{ | ||||||
| 		return |             "success": false, | ||||||
| 	} |             "message": err.Error(), | ||||||
| 	c.JSON(http.StatusOK, gin.H{ |         }) | ||||||
| 		"success": true, |         return | ||||||
| 		"message": "", |     } | ||||||
| 		"data":    users, |      | ||||||
| 	}) |     c.JSON(http.StatusOK, gin.H{ | ||||||
| 	return |         "success": true, | ||||||
|  |         "message": "", | ||||||
|  |         "data":    users, | ||||||
|  |     }) | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUsers(c *gin.Context) { | func SearchUsers(c *gin.Context) { | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ version: '3.4' | |||||||
|  |  | ||||||
| services: | services: | ||||||
|   one-api: |   one-api: | ||||||
|     image: justsong/one-api:latest |     image: "${REGISTRY:-docker.io}/justsong/one-api:latest" | ||||||
|     container_name: one-api |     container_name: one-api | ||||||
|     restart: always |     restart: always | ||||||
|     command: --log-dir /app/logs |     command: --log-dir /app/logs | ||||||
| @@ -29,12 +29,12 @@ services: | |||||||
|       retries: 3 |       retries: 3 | ||||||
|  |  | ||||||
|   redis: |   redis: | ||||||
|     image: redis:latest |     image: "${REGISTRY:-docker.io}/redis:latest" | ||||||
|     container_name: redis |     container_name: redis | ||||||
|     restart: always |     restart: always | ||||||
|  |  | ||||||
|   db: |   db: | ||||||
|     image: mysql:8.2.0 |     image: "${REGISTRY:-docker.io}/mysql:8.2.0" | ||||||
|     restart: always |     restart: always | ||||||
|     container_name: mysql |     container_name: mysql | ||||||
|     volumes: |     volumes: | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								go.mod
									
									
									
									
									
								
							| @@ -42,7 +42,8 @@ require ( | |||||||
| 	github.com/gorilla/sessions v1.2.1 // indirect | 	github.com/gorilla/sessions v1.2.1 // indirect | ||||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||||
| 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | ||||||
| 	github.com/jackc/pgx/v5 v5.3.1 // indirect | 	github.com/jackc/pgx/v5 v5.5.4 // indirect | ||||||
|  | 	github.com/jackc/puddle/v2 v2.2.1 // indirect | ||||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||||
| 	github.com/jinzhu/now v1.1.5 // indirect | 	github.com/jinzhu/now v1.1.5 // indirect | ||||||
| 	github.com/json-iterator/go v1.1.12 // indirect | 	github.com/json-iterator/go v1.1.12 // indirect | ||||||
| @@ -58,8 +59,9 @@ require ( | |||||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||||
| 	golang.org/x/arch v0.3.0 // indirect | 	golang.org/x/arch v0.3.0 // indirect | ||||||
| 	golang.org/x/net v0.17.0 // indirect | 	golang.org/x/net v0.17.0 // indirect | ||||||
|  | 	golang.org/x/sync v0.1.0 // indirect | ||||||
| 	golang.org/x/sys v0.15.0 // indirect | 	golang.org/x/sys v0.15.0 // indirect | ||||||
| 	golang.org/x/text v0.14.0 // indirect | 	golang.org/x/text v0.14.0 // indirect | ||||||
| 	google.golang.org/protobuf v1.30.0 // indirect | 	google.golang.org/protobuf v1.33.0 // indirect | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								go.sum
									
									
									
									
									
								
							| @@ -73,8 +73,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI | |||||||
| github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= | ||||||
| github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= | ||||||
| github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= | ||||||
| github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= | github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= | ||||||
| github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= | github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= | ||||||
|  | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= | ||||||
|  | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= | ||||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||||
| @@ -157,6 +159,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE | |||||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||||
| golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | ||||||
| golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | ||||||
|  | golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= | ||||||
|  | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| @@ -177,8 +181,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV | |||||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||||
| google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | ||||||
| google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||||
| google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= | google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= | ||||||
| google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= | ||||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||||
|   | |||||||
							
								
								
									
										34
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								i18n/en.json
									
									
									
									
									
								
							| @@ -8,12 +8,12 @@ | |||||||
|   "确认删除": "Confirm Delete", |   "确认删除": "Confirm Delete", | ||||||
|   "确认绑定": "Confirm Binding", |   "确认绑定": "Confirm Binding", | ||||||
|   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", |   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", | ||||||
|   "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", |   "\"渠道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", | ||||||
|   "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", |   "渠道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", | ||||||
|   "测试已在运行中": "Test is already running", |   "测试已在运行中": "Test is already running", | ||||||
|   "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", |   "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", | ||||||
|   "通道测试完成": "Channel test completed", |   "渠道测试完成": "Channel test completed", | ||||||
|   "通道测试完成,如果没有收到禁用通知,说明所有通道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", |   "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", | ||||||
|   "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", |   "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", | ||||||
|   "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", |   "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", | ||||||
|   "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", |   "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", | ||||||
| @@ -119,11 +119,11 @@ | |||||||
|   " 个月 ": " M ", |   " 个月 ": " M ", | ||||||
|   " 年 ": " y ", |   " 年 ": " y ", | ||||||
|   "未测试": "Not tested", |   "未测试": "Not tested", | ||||||
|   "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", |   "渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", | ||||||
|   "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", |   "已成功开始测试所有渠道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", | ||||||
|   "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", |   "已成功开始测试所有已启用渠道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", | ||||||
|   "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", |   "渠道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", | ||||||
|   "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", |   "已更新完毕所有已启用渠道余额!": "The balance of all enabled channels has been updated!", | ||||||
|   "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", |   "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", | ||||||
|   "名称": "Name", |   "名称": "Name", | ||||||
|   "分组": "Group", |   "分组": "Group", | ||||||
| @@ -141,9 +141,9 @@ | |||||||
|   "启用": "Enable", |   "启用": "Enable", | ||||||
|   "编辑": "Edit", |   "编辑": "Edit", | ||||||
|   "添加新的渠道": "Add a new channel", |   "添加新的渠道": "Add a new channel", | ||||||
|   "测试所有通道": "Test all channels", |   "测试所有渠道": "Test all channels", | ||||||
|   "测试所有已启用通道": "Test all enabled channels", |   "测试所有已启用渠道": "Test all enabled channels", | ||||||
|   "更新所有已启用通道余额": "Update the balance of all enabled channels", |   "更新所有已启用渠道余额": "Update the balance of all enabled channels", | ||||||
|   "刷新": "Refresh", |   "刷新": "Refresh", | ||||||
|   "处理中...": "Processing...", |   "处理中...": "Processing...", | ||||||
|   "绑定成功!": "Binding succeeded!", |   "绑定成功!": "Binding succeeded!", | ||||||
| @@ -207,11 +207,11 @@ | |||||||
|   "监控设置": "Monitoring Settings", |   "监控设置": "Monitoring Settings", | ||||||
|   "最长响应时间": "Longest Response Time", |   "最长响应时间": "Longest Response Time", | ||||||
|   "单位秒": "Unit in seconds", |   "单位秒": "Unit in seconds", | ||||||
|   "当运行通道全部测试时": "When all operating channels are tested", |   "当运行渠道全部测试时": "When all operating channels are tested", | ||||||
|   "超过此时间将自动禁用通道": "Channels will be automatically disabled if this time is exceeded", |   "超过此时间将自动禁用渠道": "Channels will be automatically disabled if this time is exceeded", | ||||||
|   "额度提醒阈值": "Quota reminder threshold", |   "额度提醒阈值": "Quota reminder threshold", | ||||||
|   "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", |   "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", | ||||||
|   "失败时自动禁用通道": "Automatically disable the channel when it fails", |   "失败时自动禁用渠道": "Automatically disable the channel when it fails", | ||||||
|   "保存监控设置": "Save Monitoring Settings", |   "保存监控设置": "Save Monitoring Settings", | ||||||
|   "额度设置": "Quota Settings", |   "额度设置": "Quota Settings", | ||||||
|   "新用户初始额度": "Initial quota for new users", |   "新用户初始额度": "Initial quota for new users", | ||||||
| @@ -405,7 +405,7 @@ | |||||||
|   "镜像": "Mirror", |   "镜像": "Mirror", | ||||||
|   "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", |   "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", | ||||||
|   "模型": "Model", |   "模型": "Model", | ||||||
|   "请选择该通道所支持的模型": "Please select the model supported by the channel", |   "请选择该渠道所支持的模型": "Please select the model supported by the channel", | ||||||
|   "填入基础模型": "Fill in the basic model", |   "填入基础模型": "Fill in the basic model", | ||||||
|   "填入所有模型": "Fill in all models", |   "填入所有模型": "Fill in all models", | ||||||
|   "清除所有模型": "Clear all models", |   "清除所有模型": "Clear all models", | ||||||
| @@ -515,7 +515,7 @@ | |||||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", |   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||||
|   "Homepage URL 填": "Fill in the Homepage URL", |   "Homepage URL 填": "Fill in the Homepage URL", | ||||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL", |   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||||
|   "请为通道命名": "Please name the channel", |   "请为渠道命名": "Please name the channel", | ||||||
|   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", |   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||||
|   "模型重定向": "Model redirection", |   "模型重定向": "Model redirection", | ||||||
|   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", |   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||||
|   | |||||||
							
								
								
									
										26
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								main.go
									
									
									
									
									
								
							| @@ -30,11 +30,25 @@ func main() { | |||||||
| 	if config.DebugEnabled { | 	if config.DebugEnabled { | ||||||
| 		logger.SysLog("running in debug mode") | 		logger.SysLog("running in debug mode") | ||||||
| 	} | 	} | ||||||
|  | 	var err error | ||||||
| 	// Initialize SQL Database | 	// Initialize SQL Database | ||||||
| 	err := model.InitDB() | 	model.DB, err = model.InitDB("SQL_DSN") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||||
| 	} | 	} | ||||||
|  | 	if os.Getenv("LOG_SQL_DSN") != "" { | ||||||
|  | 		logger.SysLog("using secondary database for table logs") | ||||||
|  | 		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.FatalLog("failed to initialize secondary database: " + err.Error()) | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		model.LOG_DB = model.DB | ||||||
|  | 	} | ||||||
|  | 	err = model.CreateRootAccountIfNeed() | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.FatalLog("database init error: " + err.Error()) | ||||||
|  | 	} | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		err := model.CloseDB() | 		err := model.CloseDB() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -64,13 +78,6 @@ func main() { | |||||||
| 		go model.SyncOptions(config.SyncFrequency) | 		go model.SyncOptions(config.SyncFrequency) | ||||||
| 		go model.SyncChannelCache(config.SyncFrequency) | 		go model.SyncChannelCache(config.SyncFrequency) | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { |  | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		go controller.AutomaticallyUpdateChannels(frequency) |  | ||||||
| 	} |  | ||||||
| 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -83,6 +90,9 @@ func main() { | |||||||
| 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | ||||||
| 		model.InitBatchUpdater() | 		model.InitBatchUpdater() | ||||||
| 	} | 	} | ||||||
|  | 	if config.EnableMetric { | ||||||
|  | 		logger.SysLog("metric enabled, will disable channel if too much request failed") | ||||||
|  | 	} | ||||||
| 	openai.InitTokenEncoders() | 	openai.InitTokenEncoders() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/blacklist" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if status.(int) == common.UserStatusDisabled { | 	if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "用户已被封禁", | 			"message": "用户已被封禁", | ||||||
| 		}) | 		}) | ||||||
|  | 		session := sessions.Default(c) | ||||||
|  | 		session.Clear() | ||||||
|  | 		_ = session.Save() | ||||||
| 		c.Abort() | 		c.Abort() | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if !userEnabled { | 		if !userEnabled || blacklist.IsUserBanned(token.UserId) { | ||||||
| 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) { | |||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			requestModel = modelRequest.Model | 			requestModel = modelRequest.Model | ||||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||||
| 				if channel != nil { | 				if channel != nil { | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"runtime/debug" | 	"runtime/debug" | ||||||
| @@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc { | |||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		defer func() { | 		defer func() { | ||||||
| 			if err := recover(); err != nil { | 			if err := recover(); err != nil { | ||||||
| 				logger.SysError(fmt.Sprintf("panic detected: %v", err)) | 				ctx := c.Request.Context() | ||||||
| 				logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | 				logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) | ||||||
|  | 				logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||||
|  | 				logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) | ||||||
|  | 				body, _ := common.GetRequestBody(c) | ||||||
|  | 				logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) | ||||||
| 				c.JSON(http.StatusInternalServerError, gin.H{ | 				c.JSON(http.StatusInternalServerError, gin.H{ | ||||||
| 					"error": gin.H{ | 					"error": gin.H{ | ||||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), | ||||||
| 						"type":    "one_api_panic", | 						"type":    "one_api_panic", | ||||||
| 					}, | 					}, | ||||||
| 				}) | 				}) | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ import ( | |||||||
|  |  | ||||||
| func RequestId() func(c *gin.Context) { | func RequestId() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		id := helper.GetTimeString() + helper.GetRandomNumberString(8) | 		id := helper.GenRequestID() | ||||||
| 		c.Set(logger.RequestIdKey, id) | 		c.Set(logger.RequestIdKey, id) | ||||||
| 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | ||||||
| 		c.Request = c.Request.WithContext(ctx) | 		c.Request = c.Request.WithContext(ctx) | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -70,31 +71,42 @@ func CacheGetUserGroup(id int) (group string, err error) { | |||||||
| 	return group, err | 	return group, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetUserQuota(id int) (quota int, err error) { | func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||||
|  | 	quota, err = GetUserQuota(id) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Error(ctx, "Redis set user quota error: "+err.Error()) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return GetUserQuota(id) | 		return GetUserQuota(id) | ||||||
| 	} | 	} | ||||||
| 	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) | 	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		quota, err = GetUserQuota(id) | 		return fetchAndUpdateUserQuota(ctx, id) | ||||||
| 		if err != nil { |  | ||||||
| 			return 0, err |  | ||||||
| 		} |  | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.SysError("Redis set user quota error: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		return quota, err |  | ||||||
| 	} | 	} | ||||||
| 	quota, err = strconv.Atoi(quotaString) | 	quota, err = strconv.ParseInt(quotaString, 10, 64) | ||||||
| 	return quota, err | 	if err != nil { | ||||||
|  | 		return 0, nil | ||||||
|  | 	} | ||||||
|  | 	if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db | ||||||
|  | 		logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) | ||||||
|  | 		return fetchAndUpdateUserQuota(ctx, id) | ||||||
|  | 	} | ||||||
|  | 	return quota, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheUpdateUserQuota(id int) error { | func CacheUpdateUserQuota(ctx context.Context, id int) error { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	quota, err := CacheGetUserQuota(id) | 	quota, err := CacheGetUserQuota(ctx, id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -102,7 +114,7 @@ func CacheUpdateUserQuota(id int) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheDecreaseUserQuota(id int, quota int) error { | func CacheDecreaseUserQuota(id int, quota int64) error { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -191,7 +203,7 @@ func SyncChannelCache(frequency int) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { | ||||||
| 	if !config.MemoryCacheEnabled { | 	if !config.MemoryCacheEnabled { | ||||||
| 		return GetRandomSatisfiedChannel(group, model) | 		return GetRandomSatisfiedChannel(group, model) | ||||||
| 	} | 	} | ||||||
| @@ -213,5 +225,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	idx := rand.Intn(endIdx) | 	idx := rand.Intn(endIdx) | ||||||
|  | 	if ignoreFirstPriority { | ||||||
|  | 		if endIdx < len(channels) { // which means there are more than one priority | ||||||
|  | 			idx = common.RandRange(endIdx, len(channels)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	return channels[idx], nil | 	return channels[idx], nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ import ( | |||||||
| type Channel struct { | type Channel struct { | ||||||
| 	Id                 int     `json:"id"` | 	Id                 int     `json:"id"` | ||||||
| 	Type               int     `json:"type" gorm:"default:0"` | 	Type               int     `json:"type" gorm:"default:0"` | ||||||
| 	Key                string  `json:"key" gorm:"not null;index"` | 	Key                string  `json:"key" gorm:"type:text"` | ||||||
| 	Status             int     `json:"status" gorm:"default:1"` | 	Status             int     `json:"status" gorm:"default:1"` | ||||||
| 	Name               string  `json:"name" gorm:"index"` | 	Name               string  `json:"name" gorm:"index"` | ||||||
| 	Weight             *uint   `json:"weight" gorm:"default:0"` | 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||||
| @@ -32,23 +32,22 @@ type Channel struct { | |||||||
| 	Config             string  `json:"config"` | 	Config             string  `json:"config"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { | ||||||
| 	var channels []*Channel | 	var channels []*Channel | ||||||
| 	var err error | 	var err error | ||||||
| 	if selectAll { | 	switch scope { | ||||||
|  | 	case "all": | ||||||
| 		err = DB.Order("id desc").Find(&channels).Error | 		err = DB.Order("id desc").Find(&channels).Error | ||||||
| 	} else { | 	case "disabled": | ||||||
|  | 		err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error | ||||||
|  | 	default: | ||||||
| 		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | 		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | ||||||
| 	} | 	} | ||||||
| 	return channels, err | 	return channels, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||||
| 	keyCol := "`key`" | 	err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error | ||||||
| 	if common.UsingPostgreSQL { |  | ||||||
| 		keyCol = `"key"` |  | ||||||
| 	} |  | ||||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error |  | ||||||
| 	return channels, err | 	return channels, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -179,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelUsedQuota(id int, quota int) { | func UpdateChannelUsedQuota(id int, quota int64) { | ||||||
| 	if config.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||||
| 		return | 		return | ||||||
| @@ -187,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) { | |||||||
| 	updateChannelUsedQuota(id, quota) | 	updateChannelUsedQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateChannelUsedQuota(id int, quota int) { | func updateChannelUsedQuota(id int, quota int64) { | ||||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update channel used quota: " + err.Error()) | 		logger.SysError("failed to update channel used quota: " + err.Error()) | ||||||
|   | |||||||
							
								
								
									
										30
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -45,13 +45,13 @@ func RecordLog(userId int, logType int, content string) { | |||||||
| 		Type:      logType, | 		Type:      logType, | ||||||
| 		Content:   content, | 		Content:   content, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := LOG_DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to record log: " + err.Error()) | 		logger.SysError("failed to record log: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { | ||||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||||
| 	if !config.LogConsumeEnabled { | 	if !config.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| @@ -66,10 +66,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | |||||||
| 		CompletionTokens: completionTokens, | 		CompletionTokens: completionTokens, | ||||||
| 		TokenName:        tokenName, | 		TokenName:        tokenName, | ||||||
| 		ModelName:        modelName, | 		ModelName:        modelName, | ||||||
| 		Quota:            quota, | 		Quota:            int(quota), | ||||||
| 		ChannelId:        channelId, | 		ChannelId:        channelId, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := LOG_DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| @@ -78,9 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | |||||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | ||||||
| 	var tx *gorm.DB | 	var tx *gorm.DB | ||||||
| 	if logType == LogTypeUnknown { | 	if logType == LogTypeUnknown { | ||||||
| 		tx = DB | 		tx = LOG_DB | ||||||
| 	} else { | 	} else { | ||||||
| 		tx = DB.Where("type = ?", logType) | 		tx = LOG_DB.Where("type = ?", logType) | ||||||
| 	} | 	} | ||||||
| 	if modelName != "" { | 	if modelName != "" { | ||||||
| 		tx = tx.Where("model_name = ?", modelName) | 		tx = tx.Where("model_name = ?", modelName) | ||||||
| @@ -107,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | |||||||
| func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | ||||||
| 	var tx *gorm.DB | 	var tx *gorm.DB | ||||||
| 	if logType == LogTypeUnknown { | 	if logType == LogTypeUnknown { | ||||||
| 		tx = DB.Where("user_id = ?", userId) | 		tx = LOG_DB.Where("user_id = ?", userId) | ||||||
| 	} else { | 	} else { | ||||||
| 		tx = DB.Where("user_id = ? and type = ?", userId, logType) | 		tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) | ||||||
| 	} | 	} | ||||||
| 	if modelName != "" { | 	if modelName != "" { | ||||||
| 		tx = tx.Where("model_name = ?", modelName) | 		tx = tx.Where("model_name = ?", modelName) | ||||||
| @@ -128,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int | |||||||
| } | } | ||||||
|  |  | ||||||
| func SearchAllLogs(keyword string) (logs []*Log, err error) { | func SearchAllLogs(keyword string) (logs []*Log, err error) { | ||||||
| 	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error | 	err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||||
| 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error | 	err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { | ||||||
| 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | 	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -162,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||||
| 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | 	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -183,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteOldLog(targetTimestamp int64) (int64, error) { | func DeleteOldLog(targetTimestamp int64) (int64, error) { | ||||||
| 	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | 	result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||||
| 	return result.RowsAffected, result.Error | 	return result.RowsAffected, result.Error | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -207,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis | |||||||
| 		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" | 		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = DB.Raw(` | 	err = LOG_DB.Raw(` | ||||||
| 		SELECT `+groupSelect+`, | 		SELECT `+groupSelect+`, | ||||||
| 		model_name, count(1) as request_count, | 		model_name, count(1) as request_count, | ||||||
| 		sum(quota) as quota, | 		sum(quota) as quota, | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/env" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"gorm.io/driver/mysql" | 	"gorm.io/driver/mysql" | ||||||
| @@ -16,12 +17,13 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var DB *gorm.DB | var DB *gorm.DB | ||||||
|  | var LOG_DB *gorm.DB | ||||||
|  |  | ||||||
| func createRootAccountIfNeed() error { | func CreateRootAccountIfNeed() error { | ||||||
| 	var user User | 	var user User | ||||||
| 	//if user.Status != util.UserStatusEnabled { | 	//if user.Status != util.UserStatusEnabled { | ||||||
| 	if err := DB.First(&user).Error; err != nil { | 	if err := DB.First(&user).Error; err != nil { | ||||||
| 		logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") | 		logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456") | ||||||
| 		hashedPassword, err := common.Password2Hash("123456") | 		hashedPassword, err := common.Password2Hash("123456") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| @@ -33,16 +35,32 @@ func createRootAccountIfNeed() error { | |||||||
| 			Status:      common.UserStatusEnabled, | 			Status:      common.UserStatusEnabled, | ||||||
| 			DisplayName: "Root User", | 			DisplayName: "Root User", | ||||||
| 			AccessToken: helper.GetUUID(), | 			AccessToken: helper.GetUUID(), | ||||||
| 			Quota:       100000000, | 			Quota:       500000000000000, | ||||||
| 		} | 		} | ||||||
| 		DB.Create(&rootUser) | 		DB.Create(&rootUser) | ||||||
|  | 		if config.InitialRootToken != "" { | ||||||
|  | 			logger.SysLog("creating initial root token as requested") | ||||||
|  | 			token := Token{ | ||||||
|  | 				Id:             1, | ||||||
|  | 				UserId:         rootUser.Id, | ||||||
|  | 				Key:            config.InitialRootToken, | ||||||
|  | 				Status:         common.TokenStatusEnabled, | ||||||
|  | 				Name:           "Initial Root Token", | ||||||
|  | 				CreatedTime:    helper.GetTimestamp(), | ||||||
|  | 				AccessedTime:   helper.GetTimestamp(), | ||||||
|  | 				ExpiredTime:    -1, | ||||||
|  | 				RemainQuota:    500000000000000, | ||||||
|  | 				UnlimitedQuota: true, | ||||||
|  | 			} | ||||||
|  | 			DB.Create(&token) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func chooseDB() (*gorm.DB, error) { | func chooseDB(envName string) (*gorm.DB, error) { | ||||||
| 	if os.Getenv("SQL_DSN") != "" { | 	if os.Getenv(envName) != "" { | ||||||
| 		dsn := os.Getenv("SQL_DSN") | 		dsn := os.Getenv(envName) | ||||||
| 		if strings.HasPrefix(dsn, "postgres://") { | 		if strings.HasPrefix(dsn, "postgres://") { | ||||||
| 			// Use PostgreSQL | 			// Use PostgreSQL | ||||||
| 			logger.SysLog("using PostgreSQL as database") | 			logger.SysLog("using PostgreSQL as database") | ||||||
| @@ -56,6 +74,7 @@ func chooseDB() (*gorm.DB, error) { | |||||||
| 		} | 		} | ||||||
| 		// Use MySQL | 		// Use MySQL | ||||||
| 		logger.SysLog("using MySQL as database") | 		logger.SysLog("using MySQL as database") | ||||||
|  | 		common.UsingMySQL = true | ||||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||||
| 			PrepareStmt: true, // precompile SQL | 			PrepareStmt: true, // precompile SQL | ||||||
| 		}) | 		}) | ||||||
| @@ -69,67 +88,78 @@ func chooseDB() (*gorm.DB, error) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func InitDB() (err error) { | func InitDB(envName string) (db *gorm.DB, err error) { | ||||||
| 	db, err := chooseDB() | 	db, err = chooseDB(envName) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		if config.DebugEnabled { | 		if config.DebugSQLEnabled { | ||||||
| 			db = db.Debug() | 			db = db.Debug() | ||||||
| 		} | 		} | ||||||
| 		DB = db | 		sqlDB, err := db.DB() | ||||||
| 		sqlDB, err := DB.DB() |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) | 		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||||
| 		sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) | 		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) | 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||||
|  |  | ||||||
| 		if !config.IsMasterNode { | 		if !config.IsMasterNode { | ||||||
| 			return nil | 			return db, err | ||||||
|  | 		} | ||||||
|  | 		if common.UsingMySQL { | ||||||
|  | 			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||||
| 		} | 		} | ||||||
| 		logger.SysLog("database migration started") | 		logger.SysLog("database migration started") | ||||||
| 		err = db.AutoMigrate(&Channel{}) | 		err = db.AutoMigrate(&Channel{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = db.AutoMigrate(&Token{}) | 		err = db.AutoMigrate(&Token{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = db.AutoMigrate(&User{}) | 		err = db.AutoMigrate(&User{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = db.AutoMigrate(&Option{}) | 		err = db.AutoMigrate(&Option{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = db.AutoMigrate(&Redemption{}) | 		err = db.AutoMigrate(&Redemption{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = db.AutoMigrate(&Ability{}) | 		err = db.AutoMigrate(&Ability{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = db.AutoMigrate(&Log{}) | 		err = db.AutoMigrate(&Log{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		logger.SysLog("database migrated") | 		logger.SysLog("database migrated") | ||||||
| 		err = createRootAccountIfNeed() | 		return db, err | ||||||
| 		return err |  | ||||||
| 	} else { | 	} else { | ||||||
| 		logger.FatalLog(err) | 		logger.FatalLog(err) | ||||||
| 	} | 	} | ||||||
| 	return err | 	return db, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CloseDB() error { | func closeDB(db *gorm.DB) error { | ||||||
| 	sqlDB, err := DB.DB() | 	sqlDB, err := db.DB() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	err = sqlDB.Close() | 	err = sqlDB.Close() | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func CloseDB() error { | ||||||
|  | 	if LOG_DB != DB { | ||||||
|  | 		err := closeDB(LOG_DB) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return closeDB(DB) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -57,13 +57,15 @@ func InitOptionMap() { | |||||||
| 	config.OptionMap["WeChatServerAddress"] = "" | 	config.OptionMap["WeChatServerAddress"] = "" | ||||||
| 	config.OptionMap["WeChatServerToken"] = "" | 	config.OptionMap["WeChatServerToken"] = "" | ||||||
| 	config.OptionMap["WeChatAccountQRCodeImageURL"] = "" | 	config.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||||
|  | 	config.OptionMap["MessagePusherAddress"] = "" | ||||||
|  | 	config.OptionMap["MessagePusherToken"] = "" | ||||||
| 	config.OptionMap["TurnstileSiteKey"] = "" | 	config.OptionMap["TurnstileSiteKey"] = "" | ||||||
| 	config.OptionMap["TurnstileSecretKey"] = "" | 	config.OptionMap["TurnstileSecretKey"] = "" | ||||||
| 	config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) | 	config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) | ||||||
| 	config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) | 	config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) | ||||||
| 	config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) | 	config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) | ||||||
| 	config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) | 	config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) | ||||||
| 	config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) | 	config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) | ||||||
| 	config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | 	config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | ||||||
| 	config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | 	config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | ||||||
| 	config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() | 	config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() | ||||||
| @@ -79,6 +81,9 @@ func InitOptionMap() { | |||||||
| func loadOptionsFromDatabase() { | func loadOptionsFromDatabase() { | ||||||
| 	options, _ := AllOption() | 	options, _ := AllOption() | ||||||
| 	for _, option := range options { | 	for _, option := range options { | ||||||
|  | 		if option.Key == "ModelRatio" { | ||||||
|  | 			option.Value = common.AddNewMissingRatio(option.Value) | ||||||
|  | 		} | ||||||
| 		err := updateOptionMap(option.Key, option.Value) | 		err := updateOptionMap(option.Key, option.Value) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("failed to update option map: " + err.Error()) | 			logger.SysError("failed to update option map: " + err.Error()) | ||||||
| @@ -179,20 +184,24 @@ func updateOptionMap(key string, value string) (err error) { | |||||||
| 		config.WeChatServerToken = value | 		config.WeChatServerToken = value | ||||||
| 	case "WeChatAccountQRCodeImageURL": | 	case "WeChatAccountQRCodeImageURL": | ||||||
| 		config.WeChatAccountQRCodeImageURL = value | 		config.WeChatAccountQRCodeImageURL = value | ||||||
|  | 	case "MessagePusherAddress": | ||||||
|  | 		config.MessagePusherAddress = value | ||||||
|  | 	case "MessagePusherToken": | ||||||
|  | 		config.MessagePusherToken = value | ||||||
| 	case "TurnstileSiteKey": | 	case "TurnstileSiteKey": | ||||||
| 		config.TurnstileSiteKey = value | 		config.TurnstileSiteKey = value | ||||||
| 	case "TurnstileSecretKey": | 	case "TurnstileSecretKey": | ||||||
| 		config.TurnstileSecretKey = value | 		config.TurnstileSecretKey = value | ||||||
| 	case "QuotaForNewUser": | 	case "QuotaForNewUser": | ||||||
| 		config.QuotaForNewUser, _ = strconv.Atoi(value) | 		config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "QuotaForInviter": | 	case "QuotaForInviter": | ||||||
| 		config.QuotaForInviter, _ = strconv.Atoi(value) | 		config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "QuotaForInvitee": | 	case "QuotaForInvitee": | ||||||
| 		config.QuotaForInvitee, _ = strconv.Atoi(value) | 		config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "QuotaRemindThreshold": | 	case "QuotaRemindThreshold": | ||||||
| 		config.QuotaRemindThreshold, _ = strconv.Atoi(value) | 		config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "PreConsumedQuota": | 	case "PreConsumedQuota": | ||||||
| 		config.PreConsumedQuota, _ = strconv.Atoi(value) | 		config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "RetryTimes": | 	case "RetryTimes": | ||||||
| 		config.RetryTimes, _ = strconv.Atoi(value) | 		config.RetryTimes, _ = strconv.Atoi(value) | ||||||
| 	case "ModelRatio": | 	case "ModelRatio": | ||||||
|   | |||||||
| @@ -14,7 +14,7 @@ type Redemption struct { | |||||||
| 	Key          string `json:"key" gorm:"type:char(32);uniqueIndex"` | 	Key          string `json:"key" gorm:"type:char(32);uniqueIndex"` | ||||||
| 	Status       int    `json:"status" gorm:"default:1"` | 	Status       int    `json:"status" gorm:"default:1"` | ||||||
| 	Name         string `json:"name" gorm:"index"` | 	Name         string `json:"name" gorm:"index"` | ||||||
| 	Quota        int    `json:"quota" gorm:"default:100"` | 	Quota        int64  `json:"quota" gorm:"bigint;default:100"` | ||||||
| 	CreatedTime  int64  `json:"created_time" gorm:"bigint"` | 	CreatedTime  int64  `json:"created_time" gorm:"bigint"` | ||||||
| 	RedeemedTime int64  `json:"redeemed_time" gorm:"bigint"` | 	RedeemedTime int64  `json:"redeemed_time" gorm:"bigint"` | ||||||
| 	Count        int    `json:"count" gorm:"-:all"` // only for api request | 	Count        int    `json:"count" gorm:"-:all"` // only for api request | ||||||
| @@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) { | |||||||
| 	return &redemption, err | 	return &redemption, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func Redeem(key string, userId int) (quota int, err error) { | func Redeem(key string, userId int) (quota int64, err error) { | ||||||
| 	if key == "" { | 	if key == "" { | ||||||
| 		return 0, errors.New("未提供兑换码") | 		return 0, errors.New("未提供兑换码") | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -19,15 +20,26 @@ type Token struct { | |||||||
| 	CreatedTime    int64  `json:"created_time" gorm:"bigint"` | 	CreatedTime    int64  `json:"created_time" gorm:"bigint"` | ||||||
| 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"` | 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"` | ||||||
| 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||||
| 	RemainQuota    int    `json:"remain_quota" gorm:"default:0"` | 	RemainQuota    int64  `json:"remain_quota" gorm:"bigint;default:0"` | ||||||
| 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"` | 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"` | ||||||
| 	UsedQuota      int    `json:"used_quota" gorm:"default:0"` // used quota | 	UsedQuota      int64  `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { | func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { | ||||||
| 	var tokens []*Token | 	var tokens []*Token | ||||||
| 	var err error | 	var err error | ||||||
| 	err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error | 	query := DB.Where("user_id = ?", userId) | ||||||
|  | 	 | ||||||
|  | 	switch order { | ||||||
|  | 	case "remain_quota": | ||||||
|  | 		query = query.Order("unlimited_quota desc, remain_quota desc") | ||||||
|  | 	case "used_quota": | ||||||
|  | 		query = query.Order("used_quota desc") | ||||||
|  | 	default: | ||||||
|  | 		query = query.Order("id desc") | ||||||
|  | 	} | ||||||
|  | 	 | ||||||
|  | 	err = query.Limit(num).Offset(startIdx).Find(&tokens).Error | ||||||
| 	return tokens, err | 	return tokens, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -137,7 +149,7 @@ func DeleteTokenById(id int, userId int) (err error) { | |||||||
| 	return token.Delete() | 	return token.Delete() | ||||||
| } | } | ||||||
|  |  | ||||||
| func IncreaseTokenQuota(id int, quota int) (err error) { | func IncreaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| @@ -148,7 +160,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	return increaseTokenQuota(id, quota) | 	return increaseTokenQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func increaseTokenQuota(id int, quota int) (err error) { | func increaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||||
| @@ -159,7 +171,7 @@ func increaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func DecreaseTokenQuota(id int, quota int) (err error) { | func DecreaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| @@ -170,7 +182,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	return decreaseTokenQuota(id, quota) | 	return decreaseTokenQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func decreaseTokenQuota(id int, quota int) (err error) { | func decreaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||||
| @@ -181,7 +193,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| @@ -213,7 +225,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | |||||||
| 			} | 			} | ||||||
| 			if email != "" { | 			if email != "" { | ||||||
| 				topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) | 				topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) | ||||||
| 				err = common.SendEmail(prompt, email, | 				err = message.SendEmail(prompt, email, | ||||||
| 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					logger.SysError("failed to send email" + err.Error()) | 					logger.SysError("failed to send email" + err.Error()) | ||||||
| @@ -231,7 +243,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func PostConsumeTokenQuota(tokenId int, quota int) (err error) { | func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||||
| 	token, err := GetTokenById(tokenId) | 	token, err := GetTokenById(tokenId) | ||||||
| 	if quota > 0 { | 	if quota > 0 { | ||||||
| 		err = DecreaseUserQuota(token.UserId, quota) | 		err = DecreaseUserQuota(token.UserId, quota) | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/blacklist" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| @@ -25,9 +26,9 @@ type User struct { | |||||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||||
| 	Quota            int    `json:"quota" gorm:"type:int;default:0"` | 	Quota            int64  `json:"quota" gorm:"bigint;default:0"` | ||||||
| 	UsedQuota        int    `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota | 	UsedQuota        int64  `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota | ||||||
| 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`               // request number | 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`             // request number | ||||||
| 	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"` | 	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||||
| 	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` | 	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` | ||||||
| 	InviterId        int    `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` | 	InviterId        int    `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` | ||||||
| @@ -39,9 +40,22 @@ func GetMaxUserId() int { | |||||||
| 	return user.Id | 	return user.Id | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllUsers(startIdx int, num int) (users []*User, err error) { | func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { | ||||||
| 	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error |     query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) | ||||||
| 	return users, err |      | ||||||
|  |     switch order { | ||||||
|  |     case "quota": | ||||||
|  |         query = query.Order("quota desc") | ||||||
|  |     case "used_quota": | ||||||
|  |         query = query.Order("used_quota desc") | ||||||
|  |     case "request_count": | ||||||
|  |         query = query.Order("request_count desc") | ||||||
|  |     default: | ||||||
|  |         query = query.Order("id desc") | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     err = query.Find(&users).Error | ||||||
|  |     return users, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUsers(keyword string) (users []*User, err error) { | func SearchUsers(keyword string) (users []*User, err error) { | ||||||
| @@ -123,6 +137,11 @@ func (user *User) Update(updatePassword bool) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	if user.Status == common.UserStatusDisabled { | ||||||
|  | 		blacklist.BanUser(user.Id) | ||||||
|  | 	} else if user.Status == common.UserStatusEnabled { | ||||||
|  | 		blacklist.UnbanUser(user.Id) | ||||||
|  | 	} | ||||||
| 	err = DB.Model(user).Updates(user).Error | 	err = DB.Model(user).Updates(user).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -131,7 +150,10 @@ func (user *User) Delete() error { | |||||||
| 	if user.Id == 0 { | 	if user.Id == 0 { | ||||||
| 		return errors.New("id 为空!") | 		return errors.New("id 为空!") | ||||||
| 	} | 	} | ||||||
| 	err := DB.Delete(user).Error | 	blacklist.BanUser(user.Id) | ||||||
|  | 	user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) | ||||||
|  | 	user.Status = common.UserStatusDeleted | ||||||
|  | 	err := DB.Model(user).Updates(user).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -265,12 +287,12 @@ func ValidateAccessToken(token string) (user *User) { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserQuota(id int) (quota int, err error) { | func GetUserQuota(id int) (quota int64, err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error | ||||||
| 	return quota, err | 	return quota, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserUsedQuota(id int) (quota int, err error) { | func GetUserUsedQuota(id int) (quota int64, err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error | ||||||
| 	return quota, err | 	return quota, err | ||||||
| } | } | ||||||
| @@ -290,7 +312,7 @@ func GetUserGroup(id int) (group string, err error) { | |||||||
| 	return group, err | 	return group, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func IncreaseUserQuota(id int, quota int) (err error) { | func IncreaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| @@ -301,12 +323,12 @@ func IncreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	return increaseUserQuota(id, quota) | 	return increaseUserQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func increaseUserQuota(id int, quota int) (err error) { | func increaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func DecreaseUserQuota(id int, quota int) (err error) { | func DecreaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| @@ -317,7 +339,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	return decreaseUserQuota(id, quota) | 	return decreaseUserQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func decreaseUserQuota(id int, quota int) (err error) { | func decreaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -327,7 +349,7 @@ func GetRootUserEmail() (email string) { | |||||||
| 	return email | 	return email | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { | ||||||
| 	if config.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||||
| @@ -336,7 +358,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | |||||||
| 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
| @@ -348,7 +370,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateUserUsedQuota(id int, quota int) { | func updateUserUsedQuota(id int, quota int64) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"used_quota": gorm.Expr("used_quota + ?", quota), | 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||||
|   | |||||||
| @@ -16,12 +16,12 @@ const ( | |||||||
| 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var batchUpdateStores []map[int]int | var batchUpdateStores []map[int]int64 | ||||||
| var batchUpdateLocks []sync.Mutex | var batchUpdateLocks []sync.Mutex | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
| 		batchUpdateStores = append(batchUpdateStores, make(map[int]int)) | 		batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) | ||||||
| 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -35,7 +35,7 @@ func InitBatchUpdater() { | |||||||
| 	}() | 	}() | ||||||
| } | } | ||||||
|  |  | ||||||
| func addNewRecord(type_ int, id int, value int) { | func addNewRecord(type_ int, id int, value int64) { | ||||||
| 	batchUpdateLocks[type_].Lock() | 	batchUpdateLocks[type_].Lock() | ||||||
| 	defer batchUpdateLocks[type_].Unlock() | 	defer batchUpdateLocks[type_].Unlock() | ||||||
| 	if _, ok := batchUpdateStores[type_][id]; !ok { | 	if _, ok := batchUpdateStores[type_][id]; !ok { | ||||||
| @@ -50,7 +50,7 @@ func batchUpdate() { | |||||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
| 		batchUpdateLocks[i].Lock() | 		batchUpdateLocks[i].Lock() | ||||||
| 		store := batchUpdateStores[i] | 		store := batchUpdateStores[i] | ||||||
| 		batchUpdateStores[i] = make(map[int]int) | 		batchUpdateStores[i] = make(map[int]int64) | ||||||
| 		batchUpdateLocks[i].Unlock() | 		batchUpdateLocks[i].Unlock() | ||||||
| 		// TODO: maybe we can combine updates with same key? | 		// TODO: maybe we can combine updates with same key? | ||||||
| 		for key, value := range store { | 		for key, value := range store { | ||||||
| @@ -68,7 +68,7 @@ func batchUpdate() { | |||||||
| 			case BatchUpdateTypeUsedQuota: | 			case BatchUpdateTypeUsedQuota: | ||||||
| 				updateUserUsedQuota(key, value) | 				updateUserUsedQuota(key, value) | ||||||
| 			case BatchUpdateTypeRequestCount: | 			case BatchUpdateTypeRequestCount: | ||||||
| 				updateUserRequestCount(key, value) | 				updateUserRequestCount(key, int(value)) | ||||||
| 			case BatchUpdateTypeChannelUsedQuota: | 			case BatchUpdateTypeChannelUsedQuota: | ||||||
| 				updateChannelUsedQuota(key, value) | 				updateChannelUsedQuota(key, value) | ||||||
| 			} | 			} | ||||||
|   | |||||||
							
								
								
									
										55
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | package monitor | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func notifyRootUser(subject string, content string) { | ||||||
|  | 	if config.MessagePusherAddress != "" { | ||||||
|  | 		err := message.SendMessage(subject, content, content) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) | ||||||
|  | 		} else { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if config.RootUserEmail == "" { | ||||||
|  | 		config.RootUserEmail = model.GetRootUserEmail() | ||||||
|  | 	} | ||||||
|  | 	err := message.SendEmail(subject, config.RootUserEmail, content) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // DisableChannel disable & notify | ||||||
|  | func DisableChannel(channelId int, channelName string, reason string) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) | ||||||
|  | 	subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId) | ||||||
|  | 	content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func MetricDisableChannel(channelId int, successRate float64) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) | ||||||
|  | 	subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) | ||||||
|  | 	content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", | ||||||
|  | 		channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // EnableChannel enable & notify | ||||||
|  | func EnableChannel(channelId int, channelName string) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) | ||||||
|  | 	subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) | ||||||
|  | 	content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
							
								
								
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | |||||||
|  | package monitor | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var store = make(map[int][]bool) | ||||||
|  | var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) | ||||||
|  | var metricFailChan = make(chan int, config.MetricFailChanSize) | ||||||
|  |  | ||||||
|  | func consumeSuccess(channelId int) { | ||||||
|  | 	if len(store[channelId]) > config.MetricQueueSize { | ||||||
|  | 		store[channelId] = store[channelId][1:] | ||||||
|  | 	} | ||||||
|  | 	store[channelId] = append(store[channelId], true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func consumeFail(channelId int) (bool, float64) { | ||||||
|  | 	if len(store[channelId]) > config.MetricQueueSize { | ||||||
|  | 		store[channelId] = store[channelId][1:] | ||||||
|  | 	} | ||||||
|  | 	store[channelId] = append(store[channelId], false) | ||||||
|  | 	successCount := 0 | ||||||
|  | 	for _, success := range store[channelId] { | ||||||
|  | 		if success { | ||||||
|  | 			successCount++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	successRate := float64(successCount) / float64(len(store[channelId])) | ||||||
|  | 	if len(store[channelId]) < config.MetricQueueSize { | ||||||
|  | 		return false, successRate | ||||||
|  | 	} | ||||||
|  | 	if successRate < config.MetricSuccessRateThreshold { | ||||||
|  | 		store[channelId] = make([]bool, 0) | ||||||
|  | 		return true, successRate | ||||||
|  | 	} | ||||||
|  | 	return false, successRate | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func metricSuccessConsumer() { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case channelId := <-metricSuccessChan: | ||||||
|  | 			consumeSuccess(channelId) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func metricFailConsumer() { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case channelId := <-metricFailChan: | ||||||
|  | 			disable, successRate := consumeFail(channelId) | ||||||
|  | 			if disable { | ||||||
|  | 				go MetricDisableChannel(channelId, successRate) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	if config.EnableMetric { | ||||||
|  | 		go metricSuccessConsumer() | ||||||
|  | 		go metricFailConsumer() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Emit(channelId int, success bool) { | ||||||
|  | 	if !config.EnableMetric { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	go func() { | ||||||
|  | 		if success { | ||||||
|  | 			metricSuccessChan <- channelId | ||||||
|  | 		} else { | ||||||
|  | 			metricFailChan <- channelId | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
| @@ -1,5 +1,5 @@ | |||||||
| [//]: # (请按照以下格式关联 issue) | [//]: # (请按照以下格式关联 issue) | ||||||
| [//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) | [//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢) | ||||||
| [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) | [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) | ||||||
| [//]: # (开发者交流群:910657413) | [//]: # (开发者交流群:910657413) | ||||||
| [//]: # (请在提交 PR 之前删除上面的注释) | [//]: # (请在提交 PR 之前删除上面的注释) | ||||||
| @@ -7,3 +7,4 @@ | |||||||
| close #issue_number | close #issue_number | ||||||
|  |  | ||||||
| 我已确认该 PR 已自测通过,相关截图如下: | 我已确认该 PR 已自测通过,相关截图如下: | ||||||
|  | (此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图) | ||||||
|   | |||||||
| @@ -53,7 +53,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon | |||||||
| 		FinishReason: "stop", | 		FinishReason: "stop", | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := openai.TextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      helper.GetUUID(), | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Choices: []openai.TextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| @@ -66,7 +66,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion | |||||||
| 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||||
| 	choice.FinishReason = &constant.StopFinishReason | 	choice.FinishReason = &constant.StopFinishReason | ||||||
| 	return &openai.ChatCompletionsStreamResponse{ | 	return &openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      helper.GetUUID(), | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   "", | 		Model:   "", | ||||||
| @@ -78,7 +78,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena | |||||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = response.Content | 	choice.Delta.Content = response.Content | ||||||
| 	return &openai.ChatCompletionsStreamResponse{ | 	return &openai.ChatCompletionsStreamResponse{ | ||||||
| 		Id:      helper.GetUUID(), | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   response.Model, | 		Model:   response.Model, | ||||||
|   | |||||||
| @@ -32,6 +32,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | |||||||
|  |  | ||||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||||
| 	channel.SetupCommonRequestHeader(c, req, meta) | 	channel.SetupCommonRequestHeader(c, req, meta) | ||||||
|  | 	if meta.IsStream { | ||||||
|  | 		req.Header.Set("Accept", "text/event-stream") | ||||||
|  | 	} | ||||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		req.Header.Set("X-DashScope-SSE", "enable") | 		req.Header.Set("X-DashScope-SSE", "enable") | ||||||
|   | |||||||
| @@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | |||||||
| 		enableSearch = true | 		enableSearch = true | ||||||
| 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||||
| 	} | 	} | ||||||
|  | 	if request.TopP >= 1 { | ||||||
|  | 		request.TopP = 0.9999 | ||||||
|  | 	} | ||||||
| 	return &ChatRequest{ | 	return &ChatRequest{ | ||||||
| 		Model: aliModel, | 		Model: aliModel, | ||||||
| 		Input: Input{ | 		Input: Input{ | ||||||
| @@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | |||||||
| 			EnableSearch:      enableSearch, | 			EnableSearch:      enableSearch, | ||||||
| 			IncrementalOutput: request.Stream, | 			IncrementalOutput: request.Stream, | ||||||
| 			Seed:              uint64(request.Seed), | 			Seed:              uint64(request.Seed), | ||||||
|  | 			MaxTokens:         request.MaxTokens, | ||||||
|  | 			Temperature:       request.Temperature, | ||||||
|  | 			TopP:              request.TopP, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -16,6 +16,8 @@ type Parameters struct { | |||||||
| 	Seed              uint64  `json:"seed,omitempty"` | 	Seed              uint64  `json:"seed,omitempty"` | ||||||
| 	EnableSearch      bool    `json:"enable_search,omitempty"` | 	EnableSearch      bool    `json:"enable_search,omitempty"` | ||||||
| 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | ||||||
|  | 	MaxTokens         int     `json:"max_tokens,omitempty"` | ||||||
|  | 	Temperature       float64 `json:"temperature,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type ChatRequest struct { | type ChatRequest struct { | ||||||
|   | |||||||
| @@ -5,7 +5,6 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -20,7 +19,7 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
| 	return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil | 	return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||||
| @@ -31,6 +30,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut | |||||||
| 		anthropicVersion = "2023-06-01" | 		anthropicVersion = "2023-06-01" | ||||||
| 	} | 	} | ||||||
| 	req.Header.Set("anthropic-version", anthropicVersion) | 	req.Header.Set("anthropic-version", anthropicVersion) | ||||||
|  | 	req.Header.Set("anthropic-beta", "messages-2023-12-15") | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -47,9 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | |||||||
|  |  | ||||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		var responseText string | 		err, usage = StreamHandler(c, resp) | ||||||
| 		err, responseText = StreamHandler(c, resp) |  | ||||||
| 		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) |  | ||||||
| 	} else { | 	} else { | ||||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||||
| 	} | 	} | ||||||
| @@ -61,5 +59,5 @@ func (a *Adaptor) GetModelList() []string { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetChannelName() string { | func (a *Adaptor) GetChannelName() string { | ||||||
| 	return "authropic" | 	return "anthropic" | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,5 +1,8 @@ | |||||||
| package anthropic | package anthropic | ||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", | 	"claude-instant-1.2", "claude-2.0", "claude-2.1", | ||||||
|  | 	"claude-3-haiku-20240307", | ||||||
|  | 	"claude-3-sonnet-20240229", | ||||||
|  | 	"claude-3-opus-20240229", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/image" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| @@ -15,73 +16,135 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func stopReasonClaude2OpenAI(reason string) string { | func stopReasonClaude2OpenAI(reason *string) string { | ||||||
| 	switch reason { | 	if reason == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	switch *reason { | ||||||
|  | 	case "end_turn": | ||||||
|  | 		return "stop" | ||||||
| 	case "stop_sequence": | 	case "stop_sequence": | ||||||
| 		return "stop" | 		return "stop" | ||||||
| 	case "max_tokens": | 	case "max_tokens": | ||||||
| 		return "length" | 		return "length" | ||||||
| 	default: | 	default: | ||||||
| 		return reason | 		return *reason | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||||
| 	claudeRequest := Request{ | 	claudeRequest := Request{ | ||||||
| 		Model:             textRequest.Model, | 		Model:       textRequest.Model, | ||||||
| 		Prompt:            "", | 		MaxTokens:   textRequest.MaxTokens, | ||||||
| 		MaxTokensToSample: textRequest.MaxTokens, | 		Temperature: textRequest.Temperature, | ||||||
| 		StopSequences:     nil, | 		TopP:        textRequest.TopP, | ||||||
| 		Temperature:       textRequest.Temperature, | 		Stream:      textRequest.Stream, | ||||||
| 		TopP:              textRequest.TopP, |  | ||||||
| 		Stream:            textRequest.Stream, |  | ||||||
| 	} | 	} | ||||||
| 	if claudeRequest.MaxTokensToSample == 0 { | 	if claudeRequest.MaxTokens == 0 { | ||||||
| 		claudeRequest.MaxTokensToSample = 1000000 | 		claudeRequest.MaxTokens = 4096 | ||||||
|  | 	} | ||||||
|  | 	// legacy model name mapping | ||||||
|  | 	if claudeRequest.Model == "claude-instant-1" { | ||||||
|  | 		claudeRequest.Model = "claude-instant-1.1" | ||||||
|  | 	} else if claudeRequest.Model == "claude-2" { | ||||||
|  | 		claudeRequest.Model = "claude-2.1" | ||||||
| 	} | 	} | ||||||
| 	prompt := "" |  | ||||||
| 	for _, message := range textRequest.Messages { | 	for _, message := range textRequest.Messages { | ||||||
| 		if message.Role == "user" { | 		if message.Role == "system" && claudeRequest.System == "" { | ||||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | 			claudeRequest.System = message.StringContent() | ||||||
| 		} else if message.Role == "assistant" { | 			continue | ||||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) |  | ||||||
| 		} else if message.Role == "system" { |  | ||||||
| 			if prompt == "" { |  | ||||||
| 				prompt = message.StringContent() |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
|  | 		claudeMessage := Message{ | ||||||
|  | 			Role: message.Role, | ||||||
|  | 		} | ||||||
|  | 		var content Content | ||||||
|  | 		if message.IsStringContent() { | ||||||
|  | 			content.Type = "text" | ||||||
|  | 			content.Text = message.StringContent() | ||||||
|  | 			claudeMessage.Content = append(claudeMessage.Content, content) | ||||||
|  | 			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		var contents []Content | ||||||
|  | 		openaiContent := message.ParseContent() | ||||||
|  | 		for _, part := range openaiContent { | ||||||
|  | 			var content Content | ||||||
|  | 			if part.Type == model.ContentTypeText { | ||||||
|  | 				content.Type = "text" | ||||||
|  | 				content.Text = part.Text | ||||||
|  | 			} else if part.Type == model.ContentTypeImageURL { | ||||||
|  | 				content.Type = "image" | ||||||
|  | 				content.Source = &ImageSource{ | ||||||
|  | 					Type: "base64", | ||||||
|  | 				} | ||||||
|  | 				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) | ||||||
|  | 				content.Source.MediaType = mimeType | ||||||
|  | 				content.Source.Data = data | ||||||
|  | 			} | ||||||
|  | 			contents = append(contents, content) | ||||||
|  | 		} | ||||||
|  | 		claudeMessage.Content = contents | ||||||
|  | 		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) | ||||||
| 	} | 	} | ||||||
| 	prompt += "\n\nAssistant:" |  | ||||||
| 	claudeRequest.Prompt = prompt |  | ||||||
| 	return &claudeRequest | 	return &claudeRequest | ||||||
| } | } | ||||||
|  |  | ||||||
| func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { | // https://docs.anthropic.com/claude/reference/messages-streaming | ||||||
|  | func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { | ||||||
|  | 	var response *Response | ||||||
|  | 	var responseText string | ||||||
|  | 	var stopReason string | ||||||
|  | 	switch claudeResponse.Type { | ||||||
|  | 	case "message_start": | ||||||
|  | 		return nil, claudeResponse.Message | ||||||
|  | 	case "content_block_start": | ||||||
|  | 		if claudeResponse.ContentBlock != nil { | ||||||
|  | 			responseText = claudeResponse.ContentBlock.Text | ||||||
|  | 		} | ||||||
|  | 	case "content_block_delta": | ||||||
|  | 		if claudeResponse.Delta != nil { | ||||||
|  | 			responseText = claudeResponse.Delta.Text | ||||||
|  | 		} | ||||||
|  | 	case "message_delta": | ||||||
|  | 		if claudeResponse.Usage != nil { | ||||||
|  | 			response = &Response{ | ||||||
|  | 				Usage: *claudeResponse.Usage, | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { | ||||||
|  | 			stopReason = *claudeResponse.Delta.StopReason | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
| 	choice.Delta.Content = claudeResponse.Completion | 	choice.Delta.Content = responseText | ||||||
| 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) | 	choice.Delta.Role = "assistant" | ||||||
|  | 	finishReason := stopReasonClaude2OpenAI(&stopReason) | ||||||
| 	if finishReason != "null" { | 	if finishReason != "null" { | ||||||
| 		choice.FinishReason = &finishReason | 		choice.FinishReason = &finishReason | ||||||
| 	} | 	} | ||||||
| 	var response openai.ChatCompletionsStreamResponse | 	var openaiResponse openai.ChatCompletionsStreamResponse | ||||||
| 	response.Object = "chat.completion.chunk" | 	openaiResponse.Object = "chat.completion.chunk" | ||||||
| 	response.Model = claudeResponse.Model | 	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||||
| 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | 	return &openaiResponse, response | ||||||
| 	return &response |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | ||||||
|  | 	var responseText string | ||||||
|  | 	if len(claudeResponse.Content) > 0 { | ||||||
|  | 		responseText = claudeResponse.Content[0].Text | ||||||
|  | 	} | ||||||
| 	choice := openai.TextResponseChoice{ | 	choice := openai.TextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
| 		Message: model.Message{ | 		Message: model.Message{ | ||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | 			Content: responseText, | ||||||
| 			Name:    nil, | 			Name:    nil, | ||||||
| 		}, | 		}, | ||||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := openai.TextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | 		Id:      fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), | ||||||
|  | 		Model:   claudeResponse.Model, | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Choices: []openai.TextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| @@ -89,17 +152,15 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { | |||||||
| 	return &fullTextResponse | 	return &fullTextResponse | ||||||
| } | } | ||||||
|  |  | ||||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
| 	responseText := "" |  | ||||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) |  | ||||||
| 	createdTime := helper.GetTimestamp() | 	createdTime := helper.GetTimestamp() | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| 		if atEOF && len(data) == 0 { | 		if atEOF && len(data) == 0 { | ||||||
| 			return 0, nil, nil | 			return 0, nil, nil | ||||||
| 		} | 		} | ||||||
| 		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
| 			return i + 4, data[0:i], nil | 			return i + 1, data[0:i], nil | ||||||
| 		} | 		} | ||||||
| 		if atEOF { | 		if atEOF { | ||||||
| 			return len(data), data, nil | 			return len(data), data, nil | ||||||
| @@ -111,29 +172,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | |||||||
| 	go func() { | 	go func() { | ||||||
| 		for scanner.Scan() { | 		for scanner.Scan() { | ||||||
| 			data := scanner.Text() | 			data := scanner.Text() | ||||||
| 			if !strings.HasPrefix(data, "event: completion") { | 			if len(data) < 6 { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			data = strings.TrimPrefix(data, "event: completion\r\ndata: ") | 			if !strings.HasPrefix(data, "data: ") { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = strings.TrimPrefix(data, "data: ") | ||||||
| 			dataChan <- data | 			dataChan <- data | ||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	common.SetEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
|  | 	var usage model.Usage | ||||||
|  | 	var modelName string | ||||||
|  | 	var id string | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	c.Stream(func(w io.Writer) bool { | ||||||
| 		select { | 		select { | ||||||
| 		case data := <-dataChan: | 		case data := <-dataChan: | ||||||
| 			// some implementations may add \r at the end of data | 			// some implementations may add \r at the end of data | ||||||
| 			data = strings.TrimSuffix(data, "\r") | 			data = strings.TrimSuffix(data, "\r") | ||||||
| 			var claudeResponse Response | 			var claudeResponse StreamResponse | ||||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| 				return true | 				return true | ||||||
| 			} | 			} | ||||||
| 			responseText += claudeResponse.Completion | 			response, meta := streamResponseClaude2OpenAI(&claudeResponse) | ||||||
| 			response := streamResponseClaude2OpenAI(&claudeResponse) | 			if meta != nil { | ||||||
| 			response.Id = responseId | 				usage.PromptTokens += meta.Usage.InputTokens | ||||||
|  | 				usage.CompletionTokens += meta.Usage.OutputTokens | ||||||
|  | 				modelName = meta.Model | ||||||
|  | 				id = fmt.Sprintf("chatcmpl-%s", meta.Id) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if response == nil { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			response.Id = id | ||||||
|  | 			response.Model = modelName | ||||||
| 			response.Created = createdTime | 			response.Created = createdTime | ||||||
| 			jsonStr, err := json.Marshal(response) | 			jsonStr, err := json.Marshal(response) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -147,11 +224,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | |||||||
| 			return false | 			return false | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	_ = resp.Body.Close() | ||||||
| 	if err != nil { | 	return nil, &usage | ||||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" |  | ||||||
| 	} |  | ||||||
| 	return nil, responseText |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
| @@ -181,11 +255,10 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st | |||||||
| 	} | 	} | ||||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||||
| 	fullTextResponse.Model = modelName | 	fullTextResponse.Model = modelName | ||||||
| 	completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName) |  | ||||||
| 	usage := model.Usage{ | 	usage := model.Usage{ | ||||||
| 		PromptTokens:     promptTokens, | 		PromptTokens:     claudeResponse.Usage.InputTokens, | ||||||
| 		CompletionTokens: completionTokens, | 		CompletionTokens: claudeResponse.Usage.OutputTokens, | ||||||
| 		TotalTokens:      promptTokens + completionTokens, | 		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse.Usage = usage | 	fullTextResponse.Usage = usage | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|   | |||||||
| @@ -1,19 +1,44 @@ | |||||||
| package anthropic | package anthropic | ||||||
|  |  | ||||||
|  | // https://docs.anthropic.com/claude/reference/messages_post | ||||||
|  |  | ||||||
| type Metadata struct { | type Metadata struct { | ||||||
| 	UserId string `json:"user_id"` | 	UserId string `json:"user_id"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type ImageSource struct { | ||||||
|  | 	Type      string `json:"type"` | ||||||
|  | 	MediaType string `json:"media_type"` | ||||||
|  | 	Data      string `json:"data"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Content struct { | ||||||
|  | 	Type   string       `json:"type"` | ||||||
|  | 	Text   string       `json:"text,omitempty"` | ||||||
|  | 	Source *ImageSource `json:"source,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Role    string    `json:"role"` | ||||||
|  | 	Content []Content `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type Request struct { | type Request struct { | ||||||
| 	Model             string   `json:"model"` | 	Model         string    `json:"model"` | ||||||
| 	Prompt            string   `json:"prompt"` | 	Messages      []Message `json:"messages"` | ||||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | 	System        string    `json:"system,omitempty"` | ||||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | 	MaxTokens     int       `json:"max_tokens,omitempty"` | ||||||
| 	Temperature       float64  `json:"temperature,omitempty"` | 	StopSequences []string  `json:"stop_sequences,omitempty"` | ||||||
| 	TopP              float64  `json:"top_p,omitempty"` | 	Stream        bool      `json:"stream,omitempty"` | ||||||
| 	TopK              int      `json:"top_k,omitempty"` | 	Temperature   float64   `json:"temperature,omitempty"` | ||||||
|  | 	TopP          float64   `json:"top_p,omitempty"` | ||||||
|  | 	TopK          int       `json:"top_k,omitempty"` | ||||||
| 	//Metadata    `json:"metadata,omitempty"` | 	//Metadata    `json:"metadata,omitempty"` | ||||||
| 	Stream bool `json:"stream,omitempty"` | } | ||||||
|  |  | ||||||
|  | type Usage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type Error struct { | type Error struct { | ||||||
| @@ -22,8 +47,29 @@ type Error struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type Response struct { | type Response struct { | ||||||
| 	Completion string `json:"completion"` | 	Id           string    `json:"id"` | ||||||
| 	StopReason string `json:"stop_reason"` | 	Type         string    `json:"type"` | ||||||
| 	Model      string `json:"model"` | 	Role         string    `json:"role"` | ||||||
| 	Error      Error  `json:"error"` | 	Content      []Content `json:"content"` | ||||||
|  | 	Model        string    `json:"model"` | ||||||
|  | 	StopReason   *string   `json:"stop_reason"` | ||||||
|  | 	StopSequence *string   `json:"stop_sequence"` | ||||||
|  | 	Usage        Usage     `json:"usage"` | ||||||
|  | 	Error        Error     `json:"error"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Delta struct { | ||||||
|  | 	Type         string  `json:"type"` | ||||||
|  | 	Text         string  `json:"text"` | ||||||
|  | 	StopReason   *string `json:"stop_reason"` | ||||||
|  | 	StopSequence *string `json:"stop_sequence"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type StreamResponse struct { | ||||||
|  | 	Type         string    `json:"type"` | ||||||
|  | 	Message      *Response `json:"message"` | ||||||
|  | 	Index        int       `json:"index"` | ||||||
|  | 	ContentBlock *Content  `json:"content_block"` | ||||||
|  | 	Delta        *Delta    `json:"delta"` | ||||||
|  | 	Usage        *Usage    `json:"usage"` | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								relay/channel/baichuan/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/channel/baichuan/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | package baichuan | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"Baichuan2-Turbo", | ||||||
|  | 	"Baichuan2-Turbo-192k", | ||||||
|  | 	"Baichuan-Text-Embedding", | ||||||
|  | } | ||||||
| @@ -2,13 +2,16 @@ package baidu | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Adaptor struct { | type Adaptor struct { | ||||||
| @@ -20,23 +23,45 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | |||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t | 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t | ||||||
| 	var fullRequestURL string | 	suffix := "chat/" | ||||||
| 	switch meta.ActualModelName { | 	if strings.HasPrefix(meta.ActualModelName, "Embedding") { | ||||||
| 	case "ERNIE-Bot-4": | 		suffix = "embeddings/" | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" |  | ||||||
| 	case "ERNIE-Bot-8K": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" |  | ||||||
| 	case "ERNIE-Bot": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" |  | ||||||
| 	case "ERNIE-Speed": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" |  | ||||||
| 	case "ERNIE-Bot-turbo": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" |  | ||||||
| 	case "BLOOMZ-7B": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" |  | ||||||
| 	case "Embedding-V1": |  | ||||||
| 		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" |  | ||||||
| 	} | 	} | ||||||
|  | 	if strings.HasPrefix(meta.ActualModelName, "bge-large") { | ||||||
|  | 		suffix = "embeddings/" | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(meta.ActualModelName, "tao-8k") { | ||||||
|  | 		suffix = "embeddings/" | ||||||
|  | 	} | ||||||
|  | 	switch meta.ActualModelName { | ||||||
|  | 	case "ERNIE-4.0": | ||||||
|  | 		suffix += "completions_pro" | ||||||
|  | 	case "ERNIE-Bot-4": | ||||||
|  | 		suffix += "completions_pro" | ||||||
|  | 	case "ERNIE-3.5-8K": | ||||||
|  | 		suffix += "completions" | ||||||
|  | 	case "ERNIE-Bot-8K": | ||||||
|  | 		suffix += "ernie_bot_8k" | ||||||
|  | 	case "ERNIE-Bot": | ||||||
|  | 		suffix += "completions" | ||||||
|  | 	case "ERNIE-Speed": | ||||||
|  | 		suffix += "ernie_speed" | ||||||
|  | 	case "ERNIE-Bot-turbo": | ||||||
|  | 		suffix += "eb-instant" | ||||||
|  | 	case "BLOOMZ-7B": | ||||||
|  | 		suffix += "bloomz_7b1" | ||||||
|  | 	case "Embedding-V1": | ||||||
|  | 		suffix += "embedding-v1" | ||||||
|  | 	case "bge-large-zh": | ||||||
|  | 		suffix += "bge_large_zh" | ||||||
|  | 	case "bge-large-en": | ||||||
|  | 		suffix += "bge_large_en" | ||||||
|  | 	case "tao-8k": | ||||||
|  | 		suffix += "tao_8k" | ||||||
|  | 	default: | ||||||
|  | 		suffix += meta.ActualModelName | ||||||
|  | 	} | ||||||
|  | 	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) | ||||||
| 	var accessToken string | 	var accessToken string | ||||||
| 	var err error | 	var err error | ||||||
| 	if accessToken, err = GetAccessToken(meta.APIKey); err != nil { | 	if accessToken, err = GetAccessToken(meta.APIKey); err != nil { | ||||||
|   | |||||||
| @@ -7,4 +7,7 @@ var ModelList = []string{ | |||||||
| 	"ERNIE-Speed", | 	"ERNIE-Speed", | ||||||
| 	"ERNIE-Bot-turbo", | 	"ERNIE-Bot-turbo", | ||||||
| 	"Embedding-V1", | 	"Embedding-V1", | ||||||
|  | 	"bge-large-zh", | ||||||
|  | 	"bge-large-en", | ||||||
|  | 	"tao-8k", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -32,9 +32,16 @@ type Message struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type ChatRequest struct { | type ChatRequest struct { | ||||||
| 	Messages []Message `json:"messages"` | 	Messages        []Message `json:"messages"` | ||||||
| 	Stream   bool      `json:"stream"` | 	Temperature     float64   `json:"temperature,omitempty"` | ||||||
| 	UserId   string    `json:"user_id,omitempty"` | 	TopP            float64   `json:"top_p,omitempty"` | ||||||
|  | 	PenaltyScore    float64   `json:"penalty_score,omitempty"` | ||||||
|  | 	Stream          bool      `json:"stream,omitempty"` | ||||||
|  | 	System          string    `json:"system,omitempty"` | ||||||
|  | 	DisableSearch   bool      `json:"disable_search,omitempty"` | ||||||
|  | 	EnableCitation  bool      `json:"enable_citation,omitempty"` | ||||||
|  | 	MaxOutputTokens int       `json:"max_output_tokens,omitempty"` | ||||||
|  | 	UserId          string    `json:"user_id,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type Error struct { | type Error struct { | ||||||
| @@ -45,28 +52,28 @@ type Error struct { | |||||||
| var baiduTokenStore sync.Map | var baiduTokenStore sync.Map | ||||||
|  |  | ||||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||||
| 	messages := make([]Message, 0, len(request.Messages)) | 	baiduRequest := ChatRequest{ | ||||||
|  | 		Messages:        make([]Message, 0, len(request.Messages)), | ||||||
|  | 		Temperature:     request.Temperature, | ||||||
|  | 		TopP:            request.TopP, | ||||||
|  | 		PenaltyScore:    request.FrequencyPenalty, | ||||||
|  | 		Stream:          request.Stream, | ||||||
|  | 		DisableSearch:   false, | ||||||
|  | 		EnableCitation:  false, | ||||||
|  | 		MaxOutputTokens: request.MaxTokens, | ||||||
|  | 		UserId:          request.User, | ||||||
|  | 	} | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		if message.Role == "system" { | ||||||
| 			messages = append(messages, Message{ | 			baiduRequest.System = message.StringContent() | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { | 		} else { | ||||||
| 			messages = append(messages, Message{ | 			baiduRequest.Messages = append(baiduRequest.Messages, Message{ | ||||||
| 				Role:    message.Role, | 				Role:    message.Role, | ||||||
| 				Content: message.StringContent(), | 				Content: message.StringContent(), | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return &ChatRequest{ | 	return &baiduRequest | ||||||
| 		Messages: messages, |  | ||||||
| 		Stream:   request.Stream, |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { | func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||||
|   | |||||||
| @@ -1,6 +1,8 @@ | |||||||
| package gemini | package gemini | ||||||
|  |  | ||||||
|  | // https://ai.google.dev/models/gemini | ||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"gemini-pro", | 	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", | ||||||
| 	"gemini-pro-vision", | 	"gemini-pro-vision", "gemini-1.0-pro-vision-001", | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								relay/channel/groq/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								relay/channel/groq/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | package groq | ||||||
|  |  | ||||||
|  | // https://console.groq.com/docs/models | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"gemma-7b-it", | ||||||
|  | 	"llama2-7b-2048", | ||||||
|  | 	"llama2-70b-4096", | ||||||
|  | 	"mixtral-8x7b-32768", | ||||||
|  | } | ||||||
							
								
								
									
										9
									
								
								relay/channel/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								relay/channel/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | |||||||
|  | package lingyiwanwu | ||||||
|  |  | ||||||
|  | // https://platform.lingyiwanwu.com/docs | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"yi-34b-chat-0205", | ||||||
|  | 	"yi-34b-chat-200k", | ||||||
|  | 	"yi-vl-plus", | ||||||
|  | } | ||||||
							
								
								
									
										7
									
								
								relay/channel/minimax/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/channel/minimax/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | package minimax | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"abab5.5s-chat", | ||||||
|  | 	"abab5.5-chat", | ||||||
|  | 	"abab6-chat", | ||||||
|  | } | ||||||
							
								
								
									
										14
									
								
								relay/channel/minimax/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								relay/channel/minimax/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | package minimax | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
|  | 	if meta.Mode == constant.RelayModeChatCompletions { | ||||||
|  | 		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil | ||||||
|  | 	} | ||||||
|  | 	return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) | ||||||
|  | } | ||||||
							
								
								
									
										10
									
								
								relay/channel/mistral/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								relay/channel/mistral/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | package mistral | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"open-mistral-7b", | ||||||
|  | 	"open-mixtral-8x7b", | ||||||
|  | 	"mistral-small-latest", | ||||||
|  | 	"mistral-medium-latest", | ||||||
|  | 	"mistral-large-latest", | ||||||
|  | 	"mistral-embed", | ||||||
|  | } | ||||||
							
								
								
									
										75
									
								
								relay/channel/ollama/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								relay/channel/ollama/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | |||||||
|  | package ollama | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Adaptor struct { | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
|  | 	// https://github.com/ollama/ollama/blob/main/docs/api.md | ||||||
|  | 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) | ||||||
|  | 	if meta.Mode == constant.RelayModeEmbeddings { | ||||||
|  | 		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) | ||||||
|  | 	} | ||||||
|  | 	return fullRequestURL, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||||
|  | 	channel.SetupCommonRequestHeader(c, req, meta) | ||||||
|  | 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||||
|  | 	if request == nil { | ||||||
|  | 		return nil, errors.New("request is nil") | ||||||
|  | 	} | ||||||
|  | 	switch relayMode { | ||||||
|  | 	case constant.RelayModeEmbeddings: | ||||||
|  | 		ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) | ||||||
|  | 		return ollamaEmbeddingRequest, nil | ||||||
|  | 	default: | ||||||
|  | 		return ConvertRequest(*request), nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||||
|  | 	return channel.DoRequestHelper(a, c, meta, requestBody) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
|  | 	if meta.IsStream { | ||||||
|  | 		err, usage = StreamHandler(c, resp) | ||||||
|  | 	} else { | ||||||
|  | 		switch meta.Mode { | ||||||
|  | 		case constant.RelayModeEmbeddings: | ||||||
|  | 			err, usage = EmbeddingHandler(c, resp) | ||||||
|  | 		default: | ||||||
|  | 			err, usage = Handler(c, resp) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) GetModelList() []string { | ||||||
|  | 	return ModelList | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) GetChannelName() string { | ||||||
|  | 	return "ollama" | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								relay/channel/ollama/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								relay/channel/ollama/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | package ollama | ||||||
|  |  | ||||||
|  | var ModelList = []string{ | ||||||
|  | 	"qwen:0.5b-chat", | ||||||
|  | } | ||||||
							
								
								
									
										237
									
								
								relay/channel/ollama/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										237
									
								
								relay/channel/ollama/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,237 @@ | |||||||
|  | package ollama | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||||
|  | 	ollamaRequest := ChatRequest{ | ||||||
|  | 		Model: request.Model, | ||||||
|  | 		Options: &Options{ | ||||||
|  | 			Seed:             int(request.Seed), | ||||||
|  | 			Temperature:      request.Temperature, | ||||||
|  | 			TopP:             request.TopP, | ||||||
|  | 			FrequencyPenalty: request.FrequencyPenalty, | ||||||
|  | 			PresencePenalty:  request.PresencePenalty, | ||||||
|  | 		}, | ||||||
|  | 		Stream: request.Stream, | ||||||
|  | 	} | ||||||
|  | 	for _, message := range request.Messages { | ||||||
|  | 		ollamaRequest.Messages = append(ollamaRequest.Messages, Message{ | ||||||
|  | 			Role:    message.Role, | ||||||
|  | 			Content: message.StringContent(), | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return &ollamaRequest | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||||
|  | 	choice := openai.TextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: model.Message{ | ||||||
|  | 			Role:    response.Message.Role, | ||||||
|  | 			Content: response.Message.Content, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	if response.Done { | ||||||
|  | 		choice.FinishReason = "stop" | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := openai.TextResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: helper.GetTimestamp(), | ||||||
|  | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
|  | 		Usage: model.Usage{ | ||||||
|  | 			PromptTokens:     response.PromptEvalCount, | ||||||
|  | 			CompletionTokens: response.EvalCount, | ||||||
|  | 			TotalTokens:      response.PromptEvalCount + response.EvalCount, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
|  | 	var choice openai.ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Role = ollamaResponse.Message.Role | ||||||
|  | 	choice.Delta.Content = ollamaResponse.Message.Content | ||||||
|  | 	if ollamaResponse.Done { | ||||||
|  | 		choice.FinishReason = &constant.StopFinishReason | ||||||
|  | 	} | ||||||
|  | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: helper.GetTimestamp(), | ||||||
|  | 		Model:   ollamaResponse.Model, | ||||||
|  | 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|  | 	var usage model.Usage | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "}\n"); i >= 0 { | ||||||
|  | 			return i + 2, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := strings.TrimPrefix(scanner.Text(), "}") | ||||||
|  | 			dataChan <- data + "}" | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	common.SetEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var ollamaResponse ChatResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &ollamaResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if ollamaResponse.EvalCount != 0 { | ||||||
|  | 				usage.PromptTokens = ollamaResponse.PromptEvalCount | ||||||
|  | 				usage.CompletionTokens = ollamaResponse.EvalCount | ||||||
|  | 				usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount | ||||||
|  | 			} | ||||||
|  | 			response := streamResponseOllama2OpenAI(&ollamaResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||||
|  | 	return &EmbeddingRequest{ | ||||||
|  | 		Model:  request.Model, | ||||||
|  | 		Prompt: strings.Join(request.ParseInput(), " "), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|  | 	var ollamaResponse EmbeddingResponse | ||||||
|  | 	err := json.NewDecoder(resp.Body).Decode(&ollamaResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if ollamaResponse.Error != "" { | ||||||
|  | 		return &model.ErrorWithStatusCode{ | ||||||
|  | 			Error: model.Error{ | ||||||
|  | 				Message: ollamaResponse.Error, | ||||||
|  | 				Type:    "ollama_error", | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    "ollama_error", | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||||
|  | 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||||
|  | 		Object: "list", | ||||||
|  | 		Data:   make([]openai.EmbeddingResponseItem, 0, 1), | ||||||
|  | 		Model:  "text-embedding-v1", | ||||||
|  | 		Usage:  model.Usage{TotalTokens: 0}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||||
|  | 		Object:    `embedding`, | ||||||
|  | 		Index:     0, | ||||||
|  | 		Embedding: response.Embedding, | ||||||
|  | 	}) | ||||||
|  | 	return &openAIEmbeddingResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|  | 	ctx := context.TODO() | ||||||
|  | 	var ollamaResponse ChatResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	logger.Debugf(ctx, "ollama response: %s", string(responseBody)) | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &ollamaResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if ollamaResponse.Error != "" { | ||||||
|  | 		return &model.ErrorWithStatusCode{ | ||||||
|  | 			Error: model.Error{ | ||||||
|  | 				Message: ollamaResponse.Error, | ||||||
|  | 				Type:    "ollama_error", | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    "ollama_error", | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseOllama2OpenAI(&ollamaResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
							
								
								
									
										47
									
								
								relay/channel/ollama/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								relay/channel/ollama/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | |||||||
|  | package ollama | ||||||
|  |  | ||||||
|  | type Options struct { | ||||||
|  | 	Seed             int     `json:"seed,omitempty"` | ||||||
|  | 	Temperature      float64 `json:"temperature,omitempty"` | ||||||
|  | 	TopK             int     `json:"top_k,omitempty"` | ||||||
|  | 	TopP             float64 `json:"top_p,omitempty"` | ||||||
|  | 	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` | ||||||
|  | 	PresencePenalty  float64 `json:"presence_penalty,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Message struct { | ||||||
|  | 	Role    string   `json:"role,omitempty"` | ||||||
|  | 	Content string   `json:"content,omitempty"` | ||||||
|  | 	Images  []string `json:"images,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatRequest struct { | ||||||
|  | 	Model    string    `json:"model,omitempty"` | ||||||
|  | 	Messages []Message `json:"messages,omitempty"` | ||||||
|  | 	Stream   bool      `json:"stream"` | ||||||
|  | 	Options  *Options  `json:"options,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ChatResponse struct { | ||||||
|  | 	Model           string  `json:"model,omitempty"` | ||||||
|  | 	CreatedAt       string  `json:"created_at,omitempty"` | ||||||
|  | 	Message         Message `json:"message,omitempty"` | ||||||
|  | 	Response        string  `json:"response,omitempty"` // for stream response | ||||||
|  | 	Done            bool    `json:"done,omitempty"` | ||||||
|  | 	TotalDuration   int     `json:"total_duration,omitempty"` | ||||||
|  | 	LoadDuration    int     `json:"load_duration,omitempty"` | ||||||
|  | 	PromptEvalCount int     `json:"prompt_eval_count,omitempty"` | ||||||
|  | 	EvalCount       int     `json:"eval_count,omitempty"` | ||||||
|  | 	EvalDuration    int     `json:"eval_duration,omitempty"` | ||||||
|  | 	Error           string  `json:"error,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingRequest struct { | ||||||
|  | 	Model  string `json:"model"` | ||||||
|  | 	Prompt string `json:"prompt"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type EmbeddingResponse struct { | ||||||
|  | 	Error     string    `json:"error,omitempty"` | ||||||
|  | 	Embedding []float64 `json:"embedding,omitempty"` | ||||||
|  | } | ||||||
| @@ -6,8 +6,7 @@ import ( | |||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -24,22 +23,23 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
| 	if meta.ChannelType == common.ChannelTypeAzure { | 	switch meta.ChannelType { | ||||||
|  | 	case common.ChannelTypeAzure: | ||||||
| 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||||
| 		requestURL := strings.Split(meta.RequestURLPath, "?")[0] | 		requestURL := strings.Split(meta.RequestURLPath, "?")[0] | ||||||
| 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) | 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) | ||||||
| 		task := strings.TrimPrefix(requestURL, "/v1/") | 		task := strings.TrimPrefix(requestURL, "/v1/") | ||||||
| 		model_ := meta.ActualModelName | 		model_ := meta.ActualModelName | ||||||
| 		model_ = strings.Replace(model_, ".", "", -1) | 		model_ = strings.Replace(model_, ".", "", -1) | ||||||
| 		// https://github.com/songquanpeng/one-api/issues/67 | 		//https://github.com/songquanpeng/one-api/issues/1191 | ||||||
| 		model_ = strings.TrimSuffix(model_, "-0301") | 		// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version} | ||||||
| 		model_ = strings.TrimSuffix(model_, "-0314") |  | ||||||
| 		model_ = strings.TrimSuffix(model_, "-0613") |  | ||||||
|  |  | ||||||
| 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||||
| 		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil | 		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil | ||||||
|  | 	case common.ChannelTypeMinimax: | ||||||
|  | 		return minimax.GetRequestURL(meta) | ||||||
|  | 	default: | ||||||
|  | 		return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil | ||||||
| 	} | 	} | ||||||
| 	return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||||
| @@ -70,7 +70,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | |||||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		var responseText string | 		var responseText string | ||||||
| 		err, responseText = StreamHandler(c, resp, meta.Mode) | 		err, responseText, _ = StreamHandler(c, resp, meta.Mode) | ||||||
| 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||||
| 	} else { | 	} else { | ||||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||||
| @@ -79,25 +79,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetModelList() []string { | func (a *Adaptor) GetModelList() []string { | ||||||
| 	switch a.ChannelType { | 	_, modelList := GetCompatibleChannelMeta(a.ChannelType) | ||||||
| 	case common.ChannelType360: | 	return modelList | ||||||
| 		return ai360.ModelList |  | ||||||
| 	case common.ChannelTypeMoonshot: |  | ||||||
| 		return moonshot.ModelList |  | ||||||
| 	default: |  | ||||||
| 		return ModelList |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetChannelName() string { | func (a *Adaptor) GetChannelName() string { | ||||||
| 	switch a.ChannelType { | 	channelName, _ := GetCompatibleChannelMeta(a.ChannelType) | ||||||
| 	case common.ChannelTypeAzure: | 	return channelName | ||||||
| 		return "azure" |  | ||||||
| 	case common.ChannelType360: |  | ||||||
| 		return "360" |  | ||||||
| 	case common.ChannelTypeMoonshot: |  | ||||||
| 		return "moonshot" |  | ||||||
| 	default: |  | ||||||
| 		return "openai" |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										46
									
								
								relay/channel/openai/compatible.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								relay/channel/openai/compatible.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | package openai | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/groq" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/mistral" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var CompatibleChannels = []int{ | ||||||
|  | 	common.ChannelTypeAzure, | ||||||
|  | 	common.ChannelType360, | ||||||
|  | 	common.ChannelTypeMoonshot, | ||||||
|  | 	common.ChannelTypeBaichuan, | ||||||
|  | 	common.ChannelTypeMinimax, | ||||||
|  | 	common.ChannelTypeMistral, | ||||||
|  | 	common.ChannelTypeGroq, | ||||||
|  | 	common.ChannelTypeLingYiWanWu, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||||
|  | 	switch channelType { | ||||||
|  | 	case common.ChannelTypeAzure: | ||||||
|  | 		return "azure", ModelList | ||||||
|  | 	case common.ChannelType360: | ||||||
|  | 		return "360", ai360.ModelList | ||||||
|  | 	case common.ChannelTypeMoonshot: | ||||||
|  | 		return "moonshot", moonshot.ModelList | ||||||
|  | 	case common.ChannelTypeBaichuan: | ||||||
|  | 		return "baichuan", baichuan.ModelList | ||||||
|  | 	case common.ChannelTypeMinimax: | ||||||
|  | 		return "minimax", minimax.ModelList | ||||||
|  | 	case common.ChannelTypeMistral: | ||||||
|  | 		return "mistralai", mistral.ModelList | ||||||
|  | 	case common.ChannelTypeGroq: | ||||||
|  | 		return "groq", groq.ModelList | ||||||
|  | 	case common.ChannelTypeLingYiWanWu: | ||||||
|  | 		return "lingyiwanwu", lingyiwanwu.ModelList | ||||||
|  | 	default: | ||||||
|  | 		return "openai", ModelList | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -14,7 +14,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) { | func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { | ||||||
| 	responseText := "" | 	responseText := "" | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
| @@ -31,6 +31,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | |||||||
| 	}) | 	}) | ||||||
| 	dataChan := make(chan string) | 	dataChan := make(chan string) | ||||||
| 	stopChan := make(chan bool) | 	stopChan := make(chan bool) | ||||||
|  | 	var usage *model.Usage | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for scanner.Scan() { | 		for scanner.Scan() { | ||||||
| 			data := scanner.Text() | 			data := scanner.Text() | ||||||
| @@ -54,6 +55,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | |||||||
| 					for _, choice := range streamResponse.Choices { | 					for _, choice := range streamResponse.Choices { | ||||||
| 						responseText += choice.Delta.Content | 						responseText += choice.Delta.Content | ||||||
| 					} | 					} | ||||||
|  | 					if streamResponse.Usage != nil { | ||||||
|  | 						usage = streamResponse.Usage | ||||||
|  | 					} | ||||||
| 				case constant.RelayModeCompletions: | 				case constant.RelayModeCompletions: | ||||||
| 					var streamResponse CompletionsStreamResponse | 					var streamResponse CompletionsStreamResponse | ||||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||||
| @@ -86,9 +90,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | |||||||
| 	}) | 	}) | ||||||
| 	err := resp.Body.Close() | 	err := resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil | ||||||
| 	} | 	} | ||||||
| 	return nil, responseText | 	return nil, responseText, usage | ||||||
| } | } | ||||||
|  |  | ||||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||||
|   | |||||||
| @@ -118,8 +118,10 @@ type ImageResponse struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type ChatCompletionsStreamResponseChoice struct { | type ChatCompletionsStreamResponseChoice struct { | ||||||
|  | 	Index int `json:"index"` | ||||||
| 	Delta struct { | 	Delta struct { | ||||||
| 		Content string `json:"content"` | 		Content string `json:"content"` | ||||||
|  | 		Role    string `json:"role,omitempty"` | ||||||
| 	} `json:"delta"` | 	} `json:"delta"` | ||||||
| 	FinishReason *string `json:"finish_reason,omitempty"` | 	FinishReason *string `json:"finish_reason,omitempty"` | ||||||
| } | } | ||||||
| @@ -130,6 +132,7 @@ type ChatCompletionsStreamResponse struct { | |||||||
| 	Created int64                                 `json:"created"` | 	Created int64                                 `json:"created"` | ||||||
| 	Model   string                                `json:"model"` | 	Model   string                                `json:"model"` | ||||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||||
|  | 	Usage   *model.Usage                          `json:"usage"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type CompletionsStreamResponse struct { | type CompletionsStreamResponse struct { | ||||||
|   | |||||||
| @@ -28,17 +28,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | |||||||
| 	messages := make([]Message, 0, len(request.Messages)) | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	for i := 0; i < len(request.Messages); i++ { | 	for i := 0; i < len(request.Messages); i++ { | ||||||
| 		message := request.Messages[i] | 		message := request.Messages[i] | ||||||
| 		if message.Role == "system" { |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		messages = append(messages, Message{ | 		messages = append(messages, Message{ | ||||||
| 			Content: message.StringContent(), | 			Content: message.StringContent(), | ||||||
| 			Role:    message.Role, | 			Role:    message.Role, | ||||||
| @@ -81,6 +70,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | |||||||
|  |  | ||||||
| func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||||
| 	response := openai.ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   "tencent-hunyuan", | 		Model:   "tencent-hunyuan", | ||||||
|   | |||||||
| @@ -27,21 +27,10 @@ import ( | |||||||
| func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | ||||||
| 	messages := make([]Message, 0, len(request.Messages)) | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		messages = append(messages, Message{ | ||||||
| 			messages = append(messages, Message{ | 			Role:    message.Role, | ||||||
| 				Role:    "user", | 			Content: message.StringContent(), | ||||||
| 				Content: message.StringContent(), | 		}) | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    message.Role, |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	xunfeiRequest := ChatRequest{} | 	xunfeiRequest := ChatRequest{} | ||||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||||
| @@ -70,6 +59,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | |||||||
| 		FinishReason: constant.StopFinishReason, | 		FinishReason: constant.StopFinishReason, | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := openai.TextResponse{ | 	fullTextResponse := openai.TextResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Choices: []openai.TextResponseChoice{choice}, | 		Choices: []openai.TextResponseChoice{choice}, | ||||||
| @@ -92,6 +82,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl | |||||||
| 		choice.FinishReason = &constant.StopFinishReason | 		choice.FinishReason = &constant.StopFinishReason | ||||||
| 	} | 	} | ||||||
| 	response := openai.ChatCompletionsStreamResponse{ | 	response := openai.ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||||
| 		Object:  "chat.completion.chunk", | 		Object:  "chat.completion.chunk", | ||||||
| 		Created: helper.GetTimestamp(), | 		Created: helper.GetTimestamp(), | ||||||
| 		Model:   "SparkDesk", | 		Model:   "SparkDesk", | ||||||
| @@ -130,7 +121,7 @@ func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId | |||||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) | ||||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	common.SetEventStreamHeaders(c) | 	common.SetEventStreamHeaders(c) | ||||||
| 	var usage model.Usage | 	var usage model.Usage | ||||||
| @@ -160,7 +151,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin | |||||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) | ||||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | 		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	var usage model.Usage | 	var usage model.Usage | ||||||
| 	var content string | 	var content string | ||||||
| @@ -180,11 +171,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||||
| 		xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ | 		return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil | ||||||
| 			{ |  | ||||||
| 				Content: "", |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	xunfeiResponse.Payload.Choices.Text[0].Content = content | 	xunfeiResponse.Payload.Choices.Text[0].Content = content | ||||||
|  |  | ||||||
| @@ -211,15 +198,21 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|  | 	_, msg, err := conn.ReadMessage() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	dataChan := make(chan ChatResponse) | 	dataChan := make(chan ChatResponse) | ||||||
| 	stopChan := make(chan bool) | 	stopChan := make(chan bool) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| 			_, msg, err := conn.ReadMessage() | 			if msg == nil { | ||||||
| 			if err != nil { | 				_, msg, err = conn.ReadMessage() | ||||||
| 				logger.SysError("error reading stream response: " + err.Error()) | 				if err != nil { | ||||||
| 				break | 					logger.SysError("error reading stream response: " + err.Error()) | ||||||
|  | 					break | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			var response ChatResponse | 			var response ChatResponse | ||||||
| 			err = json.Unmarshal(msg, &response) | 			err = json.Unmarshal(msg, &response) | ||||||
| @@ -227,6 +220,7 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, | |||||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
|  | 			msg = nil | ||||||
| 			dataChan <- response | 			dataChan <- response | ||||||
| 			if response.Payload.Choices.Status == 2 { | 			if response.Payload.Choices.Status == 2 { | ||||||
| 				err := conn.Close() | 				err := conn.Close() | ||||||
|   | |||||||
| @@ -5,20 +5,36 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel" | 	"github.com/songquanpeng/one-api/relay/channel" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
|  | 	"math" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Adaptor struct { | type Adaptor struct { | ||||||
|  | 	APIVersion string | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Adaptor) Init(meta *util.RelayMeta) { | func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) SetVersionByModeName(modelName string) { | ||||||
|  | 	if strings.HasPrefix(modelName, "glm-") { | ||||||
|  | 		a.APIVersion = "v4" | ||||||
|  | 	} else { | ||||||
|  | 		a.APIVersion = "v3" | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||||
|  | 	a.SetVersionByModeName(meta.ActualModelName) | ||||||
|  | 	if a.APIVersion == "v4" { | ||||||
|  | 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil | ||||||
|  | 	} | ||||||
| 	method := "invoke" | 	method := "invoke" | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		method = "sse-invoke" | 		method = "sse-invoke" | ||||||
| @@ -37,6 +53,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | |||||||
| 	if request == nil { | 	if request == nil { | ||||||
| 		return nil, errors.New("request is nil") | 		return nil, errors.New("request is nil") | ||||||
| 	} | 	} | ||||||
|  | 	// TopP (0.0, 1.0) | ||||||
|  | 	request.TopP = math.Min(0.99, request.TopP) | ||||||
|  | 	request.TopP = math.Max(0.01, request.TopP) | ||||||
|  |  | ||||||
|  | 	// Temperature (0.0, 1.0) | ||||||
|  | 	request.Temperature = math.Min(0.99, request.Temperature) | ||||||
|  | 	request.Temperature = math.Max(0.01, request.Temperature) | ||||||
|  | 	a.SetVersionByModeName(request.Model) | ||||||
|  | 	if a.APIVersion == "v4" { | ||||||
|  | 		return request, nil | ||||||
|  | 	} | ||||||
| 	return ConvertRequest(*request), nil | 	return ConvertRequest(*request), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -44,7 +71,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | |||||||
| 	return channel.DoRequestHelper(a, c, meta, requestBody) | 	return channel.DoRequestHelper(a, c, meta, requestBody) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
|  | 	if meta.IsStream { | ||||||
|  | 		err, _, usage = openai.StreamHandler(c, resp, meta.Mode) | ||||||
|  | 	} else { | ||||||
|  | 		err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||||
|  | 	if a.APIVersion == "v4" { | ||||||
|  | 		return a.DoResponseV4(c, resp, meta) | ||||||
|  | 	} | ||||||
| 	if meta.IsStream { | 	if meta.IsStream { | ||||||
| 		err, usage = StreamHandler(c, resp) | 		err, usage = StreamHandler(c, resp) | ||||||
| 	} else { | 	} else { | ||||||
|   | |||||||
| @@ -2,4 +2,5 @@ package zhipu | |||||||
|  |  | ||||||
| var ModelList = []string{ | var ModelList = []string{ | ||||||
| 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | ||||||
|  | 	"glm-4", "glm-4v", "glm-3-turbo", | ||||||
| } | } | ||||||
|   | |||||||
| @@ -76,21 +76,10 @@ func GetToken(apikey string) string { | |||||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *Request { | func ConvertRequest(request model.GeneralOpenAIRequest) *Request { | ||||||
| 	messages := make([]Message, 0, len(request.Messages)) | 	messages := make([]Message, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
| 		if message.Role == "system" { | 		messages = append(messages, Message{ | ||||||
| 			messages = append(messages, Message{ | 			Role:    message.Role, | ||||||
| 				Role:    "system", | 			Content: message.StringContent(), | ||||||
| 				Content: message.StringContent(), | 		}) | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { |  | ||||||
| 			messages = append(messages, Message{ |  | ||||||
| 				Role:    message.Role, |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	return &Request{ | 	return &Request{ | ||||||
| 		Prompt:      messages, | 		Prompt:      messages, | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ const ( | |||||||
| 	APITypeAIProxyLibrary | 	APITypeAIProxyLibrary | ||||||
| 	APITypeTencent | 	APITypeTencent | ||||||
| 	APITypeGemini | 	APITypeGemini | ||||||
|  | 	APITypeOllama | ||||||
|  |  | ||||||
| 	APITypeDummy // this one is only for count, do not add any channel after this | 	APITypeDummy // this one is only for count, do not add any channel after this | ||||||
| ) | ) | ||||||
| @@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int { | |||||||
| 		apiType = APITypeTencent | 		apiType = APITypeTencent | ||||||
| 	case common.ChannelTypeGemini: | 	case common.ChannelTypeGemini: | ||||||
| 		apiType = APITypeGemini | 		apiType = APITypeGemini | ||||||
|  | 	case common.ChannelTypeOllama: | ||||||
|  | 		apiType = APITypeOllama | ||||||
| 	} | 	} | ||||||
| 	return apiType | 	return apiType | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										24
									
								
								relay/constant/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								relay/constant/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | package constant | ||||||
|  |  | ||||||
|  | var DalleSizeRatios = map[string]map[string]float64{ | ||||||
|  | 	"dall-e-2": { | ||||||
|  | 		"256x256":   1, | ||||||
|  | 		"512x512":   1.125, | ||||||
|  | 		"1024x1024": 1.25, | ||||||
|  | 	}, | ||||||
|  | 	"dall-e-3": { | ||||||
|  | 		"1024x1024": 1, | ||||||
|  | 		"1024x1792": 2, | ||||||
|  | 		"1792x1024": 2, | ||||||
|  | 	}, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DalleGenerationImageAmounts = map[string][2]int{ | ||||||
|  | 	"dall-e-2": {1, 10}, | ||||||
|  | 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DalleImagePromptLengthLimitations = map[string]int{ | ||||||
|  | 	"dall-e-2": 1000, | ||||||
|  | 	"dall-e-3": 4000, | ||||||
|  | } | ||||||
| @@ -22,6 +22,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||||
|  | 	ctx := c.Request.Context() | ||||||
| 	audioModel := "whisper-1" | 	audioModel := "whisper-1" | ||||||
|  |  | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| @@ -49,16 +50,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	modelRatio := common.GetModelRatio(audioModel) | 	modelRatio := common.GetModelRatio(audioModel) | ||||||
| 	groupRatio := common.GetGroupRatio(group) | 	groupRatio := common.GetGroupRatio(group) | ||||||
| 	ratio := modelRatio * groupRatio | 	ratio := modelRatio * groupRatio | ||||||
| 	var quota int | 	var quota int64 | ||||||
| 	var preConsumedQuota int | 	var preConsumedQuota int64 | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case constant.RelayModeAudioSpeech: | 	case constant.RelayModeAudioSpeech: | ||||||
| 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) | 		preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) | ||||||
| 		quota = preConsumedQuota | 		quota = preConsumedQuota | ||||||
| 	default: | 	default: | ||||||
| 		preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) | 		preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) | ||||||
| 	} | 	} | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | 	userQuota, err := model.CacheGetUserQuota(ctx, userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| @@ -82,6 +83,24 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | 			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	succeed := false | ||||||
|  | 	defer func() { | ||||||
|  | 		if succeed { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		if preConsumedQuota > 0 { | ||||||
|  | 			// we need to roll back the pre-consumed quota | ||||||
|  | 			defer func(ctx context.Context) { | ||||||
|  | 				go func() { | ||||||
|  | 					// negative means add quota back for token & user | ||||||
|  | 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||||
|  | 					if err != nil { | ||||||
|  | 						logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) | ||||||
|  | 					} | ||||||
|  | 				}() | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
| 	// map model name | 	// map model name | ||||||
| 	modelMapping := c.GetString("model_mapping") | 	modelMapping := c.GetString("model_mapping") | ||||||
| @@ -103,10 +122,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 	if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | 	if channelType == common.ChannelTypeAzure { | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api |  | ||||||
| 		apiVersion := util.GetAzureAPIVersion(c) | 		apiVersion := util.GetAzureAPIVersion(c) | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | 		if relayMode == constant.RelayModeAudioTranscription { | ||||||
|  | 			// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | ||||||
|  | 		} else if relayMode == constant.RelayModeAudioSpeech { | ||||||
|  | 			// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	requestBody := &bytes.Buffer{} | 	requestBody := &bytes.Buffer{} | ||||||
| @@ -122,7 +146,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | 	if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure { | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
| @@ -183,24 +207,13 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		quota = openai.CountTokenText(text, audioModel) | 		quota = int64(openai.CountTokenText(text, audioModel)) | ||||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||||
| 	} | 	} | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if resp.StatusCode != http.StatusOK { | ||||||
| 		if preConsumedQuota > 0 { |  | ||||||
| 			// we need to roll back the pre-consumed quota |  | ||||||
| 			defer func(ctx context.Context) { |  | ||||||
| 				go func() { |  | ||||||
| 					// negative means add quota back for token & user |  | ||||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) |  | ||||||
| 					if err != nil { |  | ||||||
| 						logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) |  | ||||||
| 					} |  | ||||||
| 				}() |  | ||||||
| 			}(c.Request.Context()) |  | ||||||
| 		} |  | ||||||
| 		return util.RelayErrorHandler(resp) | 		return util.RelayErrorHandler(resp) | ||||||
| 	} | 	} | ||||||
|  | 	succeed = true | ||||||
| 	quotaDelta := quota - preConsumedQuota | 	quotaDelta := quota - preConsumedQuota | ||||||
| 	defer func(ctx context.Context) { | 	defer func(ctx context.Context) { | ||||||
| 		go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | 		go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | ||||||
|   | |||||||
| @@ -36,6 +36,65 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener | |||||||
| 	return textRequest, nil | 	return textRequest, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { | ||||||
|  | 	imageRequest := &openai.ImageRequest{} | ||||||
|  | 	err := common.UnmarshalBodyReusable(c, imageRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.N == 0 { | ||||||
|  | 		imageRequest.N = 1 | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.Size == "" { | ||||||
|  | 		imageRequest.Size = "1024x1024" | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.Model == "" { | ||||||
|  | 		imageRequest.Model = "dall-e-2" | ||||||
|  | 	} | ||||||
|  | 	return imageRequest, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { | ||||||
|  | 	// model validation | ||||||
|  | 	_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||||
|  | 	if !hasValidSize { | ||||||
|  | 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||||
|  | 	} | ||||||
|  | 	// check prompt length | ||||||
|  | 	if imageRequest.Prompt == "" { | ||||||
|  | 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||||
|  | 	} | ||||||
|  | 	if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { | ||||||
|  | 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||||
|  | 	} | ||||||
|  | 	// Number of generated images validation | ||||||
|  | 	if !isWithinRange(imageRequest.Model, imageRequest.N) { | ||||||
|  | 		// channel not azure | ||||||
|  | 		if meta.ChannelType != common.ChannelTypeAzure { | ||||||
|  | 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { | ||||||
|  | 	if imageRequest == nil { | ||||||
|  | 		return 0, errors.New("imageRequest is nil") | ||||||
|  | 	} | ||||||
|  | 	imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] | ||||||
|  | 	if !hasValidSize { | ||||||
|  | 		return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) | ||||||
|  | 	} | ||||||
|  | 	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { | ||||||
|  | 		if imageRequest.Size == "1024x1024" { | ||||||
|  | 			imageCostRatio *= 2 | ||||||
|  | 		} else { | ||||||
|  | 			imageCostRatio *= 1.5 | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return imageCostRatio, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { | ||||||
| 	switch relayMode { | 	switch relayMode { | ||||||
| 	case constant.RelayModeChatCompletions: | 	case constant.RelayModeChatCompletions: | ||||||
| @@ -48,18 +107,18 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int | |||||||
| 	return 0 | 	return 0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int { | func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { | ||||||
| 	preConsumedTokens := config.PreConsumedQuota | 	preConsumedTokens := config.PreConsumedQuota | ||||||
| 	if textRequest.MaxTokens != 0 { | 	if textRequest.MaxTokens != 0 { | ||||||
| 		preConsumedTokens = promptTokens + textRequest.MaxTokens | 		preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens) | ||||||
| 	} | 	} | ||||||
| 	return int(float64(preConsumedTokens) * ratio) | 	return int64(float64(preConsumedTokens) * ratio) | ||||||
| } | } | ||||||
|  |  | ||||||
| func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { | func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { | ||||||
| 	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) | 	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) | ||||||
|  |  | ||||||
| 	userQuota, err := model.CacheGetUserQuota(meta.UserId) | 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| @@ -85,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR | |||||||
| 	return preConsumedQuota, nil | 	return preConsumedQuota, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { | func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { | ||||||
| 	if usage == nil { | 	if usage == nil { | ||||||
| 		logger.Error(ctx, "usage is nil, which is unexpected") | 		logger.Error(ctx, "usage is nil, which is unexpected") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	quota := 0 | 	var quota int64 | ||||||
| 	completionRatio := common.GetCompletionRatio(textRequest.Model) | 	completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||||
| 	promptTokens := usage.PromptTokens | 	promptTokens := usage.PromptTokens | ||||||
| 	completionTokens := usage.CompletionTokens | 	completionTokens := usage.CompletionTokens | ||||||
| 	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | 	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||||
| 	if ratio != 0 && quota <= 0 { | 	if ratio != 0 && quota <= 0 { | ||||||
| 		quota = 1 | 		quota = 1 | ||||||
| 	} | 	} | ||||||
| @@ -109,14 +168,12 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Error(ctx, "error consuming token remain quota: "+err.Error()) | 		logger.Error(ctx, "error consuming token remain quota: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| 	err = model.CacheUpdateUserQuota(meta.UserId) | 	err = model.CacheUpdateUserQuota(ctx, meta.UserId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| 	if quota != 0 { | 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | ||||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) | 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | ||||||
| 		model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) | 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||||
| 		model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | 	model.UpdateChannelUsedQuota(meta.ChannelId, quota) | ||||||
| 		model.UpdateChannelUsedQuota(meta.ChannelId, quota) |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/constant" | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" | 	"github.com/songquanpeng/one-api/relay/util" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -20,122 +21,67 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func isWithinRange(element string, value int) bool { | func isWithinRange(element string, value int) bool { | ||||||
| 	if _, ok := common.DalleGenerationImageAmounts[element]; !ok { | 	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 	min := common.DalleGenerationImageAmounts[element][0] | 	min := constant.DalleGenerationImageAmounts[element][0] | ||||||
| 	max := common.DalleGenerationImageAmounts[element][1] | 	max := constant.DalleGenerationImageAmounts[element][1] | ||||||
|  |  | ||||||
| 	return value >= min && value <= max | 	return value >= min && value <= max | ||||||
| } | } | ||||||
|  |  | ||||||
| func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||||
| 	imageModel := "dall-e-2" | 	ctx := c.Request.Context() | ||||||
| 	imageSize := "1024x1024" | 	meta := util.GetRelayMeta(c) | ||||||
|  | 	imageRequest, err := getImageRequest(c, meta.Mode) | ||||||
| 	tokenId := c.GetInt("token_id") |  | ||||||
| 	channelType := c.GetInt("channel") |  | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
| 	userId := c.GetInt("id") |  | ||||||
| 	group := c.GetString("group") |  | ||||||
|  |  | ||||||
| 	var imageRequest openai.ImageRequest |  | ||||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | 		logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) | ||||||
| 	} | 		return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) | ||||||
|  |  | ||||||
| 	if imageRequest.N == 0 { |  | ||||||
| 		imageRequest.N = 1 |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Size validation |  | ||||||
| 	if imageRequest.Size != "" { |  | ||||||
| 		imageSize = imageRequest.Size |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Model validation |  | ||||||
| 	if imageRequest.Model != "" { |  | ||||||
| 		imageModel = imageRequest.Model |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] |  | ||||||
|  |  | ||||||
| 	// Check if model is supported |  | ||||||
| 	if hasValidSize { |  | ||||||
| 		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { |  | ||||||
| 			if imageSize == "1024x1024" { |  | ||||||
| 				imageCostRatio *= 2 |  | ||||||
| 			} else { |  | ||||||
| 				imageCostRatio *= 1.5 |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Prompt validation |  | ||||||
| 	if imageRequest.Prompt == "" { |  | ||||||
| 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check prompt length |  | ||||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { |  | ||||||
| 		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Number of generated images validation |  | ||||||
| 	if !isWithinRange(imageModel, imageRequest.N) { |  | ||||||
| 		// channel not azure |  | ||||||
| 		if channelType != common.ChannelTypeAzure { |  | ||||||
| 			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// map model name | 	// map model name | ||||||
| 	modelMapping := c.GetString("model_mapping") | 	var isModelMapped bool | ||||||
| 	isModelMapped := false | 	meta.OriginModelName = imageRequest.Model | ||||||
| 	if modelMapping != "" { | 	imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) | ||||||
| 		modelMap := make(map[string]string) | 	meta.ActualModelName = imageRequest.Model | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) |  | ||||||
| 		if err != nil { | 	// model validation | ||||||
| 			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | 	bizErr := validateImageRequest(imageRequest, meta) | ||||||
| 		} | 	if bizErr != nil { | ||||||
| 		if modelMap[imageModel] != "" { | 		return bizErr | ||||||
| 			imageModel = modelMap[imageModel] |  | ||||||
| 			isModelMapped = true |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] |  | ||||||
|  | 	imageCostRatio, err := getImageCostRatio(imageRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
| 	if c.GetString("base_url") != "" { | 	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) | ||||||
| 		baseURL = c.GetString("base_url") | 	if meta.ChannelType == common.ChannelTypeAzure { | ||||||
| 	} |  | ||||||
| 	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) |  | ||||||
| 	if channelType == common.ChannelTypeAzure { |  | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||||
| 		apiVersion := util.GetAzureAPIVersion(c) | 		apiVersion := util.GetAzureAPIVersion(c) | ||||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var requestBody io.Reader | 	var requestBody io.Reader | ||||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | 	if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body | ||||||
| 		jsonStr, err := json.Marshal(imageRequest) | 		jsonStr, err := json.Marshal(imageRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	} else { | 	} else { | ||||||
| 		requestBody = c.Request.Body | 		requestBody = c.Request.Body | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	modelRatio := common.GetModelRatio(imageModel) | 	modelRatio := common.GetModelRatio(imageRequest.Model) | ||||||
| 	groupRatio := common.GetGroupRatio(group) | 	groupRatio := common.GetGroupRatio(meta.Group) | ||||||
| 	ratio := modelRatio * groupRatio | 	ratio := modelRatio * groupRatio | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||||
|  |  | ||||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | 	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | ||||||
|  |  | ||||||
| 	if userQuota-quota < 0 { | 	if userQuota-quota < 0 { | ||||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| @@ -146,7 +92,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	token := c.Request.Header.Get("Authorization") | 	token := c.Request.Header.Get("Authorization") | ||||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication | 	if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication | ||||||
| 		token = strings.TrimPrefix(token, "Bearer ") | 		token = strings.TrimPrefix(token, "Bearer ") | ||||||
| 		req.Header.Set("api-key", token) | 		req.Header.Set("api-key", token) | ||||||
| 	} else { | 	} else { | ||||||
| @@ -169,25 +115,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	var textResponse openai.ImageResponse | 	var imageResponse openai.ImageResponse | ||||||
|  |  | ||||||
| 	defer func(ctx context.Context) { | 	defer func(ctx context.Context) { | ||||||
| 		if resp.StatusCode != http.StatusOK { | 		if resp.StatusCode != http.StatusOK { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) | 		err := model.PostConsumeTokenQuota(meta.TokenId, quota) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("error consuming token remain quota: " + err.Error()) | 			logger.SysError("error consuming token remain quota: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		err = model.CacheUpdateUserQuota(userId) | 		err = model.CacheUpdateUserQuota(ctx, meta.UserId) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("error update user quota cache: " + err.Error()) | 			logger.SysError("error update user quota cache: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		if quota != 0 { | 		if quota != 0 { | ||||||
| 			tokenName := c.GetString("token_name") | 			tokenName := c.GetString("token_name") | ||||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | 			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) | ||||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||||
| 			channelId := c.GetInt("channel_id") | 			channelId := c.GetInt("channel_id") | ||||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | 			model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 		} | 		} | ||||||
| @@ -202,7 +148,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &textResponse) | 	err = json.Unmarshal(responseBody, &imageResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -55,7 +55,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 	var requestBody io.Reader | 	var requestBody io.Reader | ||||||
| 	if meta.APIType == constant.APITypeOpenAI { | 	if meta.APIType == constant.APITypeOpenAI { | ||||||
| 		// no need to convert request for openai | 		// no need to convert request for openai | ||||||
| 		if isModelMapped { | 		shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan | ||||||
|  | 		if shouldResetRequestBody { | ||||||
| 			jsonStr, err := json.Marshal(textRequest) | 			jsonStr, err := json.Marshal(textRequest) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | 				return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | ||||||
| @@ -73,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | 			return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
|  | 		logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) | ||||||
| 		requestBody = bytes.NewBuffer(jsonData) | 		requestBody = bytes.NewBuffer(jsonData) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -82,11 +84,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | |||||||
| 		logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) | 		logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) | ||||||
| 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| 	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | 	errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if errorHappened { | ||||||
| 		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) | 		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) | ||||||
| 		return util.RelayErrorHandler(resp) | 		return util.RelayErrorHandler(resp) | ||||||
| 	} | 	} | ||||||
|  | 	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
|  |  | ||||||
| 	// do response | 	// do response | ||||||
| 	usage, respErr := adaptor.DoResponse(c, resp, meta) | 	usage, respErr := adaptor.DoResponse(c, resp, meta) | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/relay/channel/anthropic" | 	"github.com/songquanpeng/one-api/relay/channel/anthropic" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/baidu" | 	"github.com/songquanpeng/one-api/relay/channel/baidu" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/gemini" | 	"github.com/songquanpeng/one-api/relay/channel/gemini" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channel/ollama" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/palm" | 	"github.com/songquanpeng/one-api/relay/channel/palm" | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/tencent" | 	"github.com/songquanpeng/one-api/relay/channel/tencent" | ||||||
| @@ -37,6 +38,8 @@ func GetAdaptor(apiType int) channel.Adaptor { | |||||||
| 		return &xunfei.Adaptor{} | 		return &xunfei.Adaptor{} | ||||||
| 	case constant.APITypeZhipu: | 	case constant.APITypeZhipu: | ||||||
| 		return &zhipu.Adaptor{} | 		return &zhipu.Adaptor{} | ||||||
|  | 	case constant.APITypeOllama: | ||||||
|  | 		return &ollama.Adaptor{} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ import ( | |||||||
| 	"github.com/songquanpeng/one-api/model" | 	"github.com/songquanpeng/one-api/model" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { | func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { | ||||||
| 	if preConsumedQuota != 0 { | 	if preConsumedQuota != 0 { | ||||||
| 		go func(ctx context.Context) { | 		go func(ctx context.Context) { | ||||||
| 			// return pre-consumed quota | 			// return pre-consumed quota | ||||||
|   | |||||||
| @@ -27,7 +27,23 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { | |||||||
| 	if statusCode == http.StatusUnauthorized { | 	if statusCode == http.StatusUnauthorized { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | 	switch err.Type { | ||||||
|  | 	case "insufficient_quota": | ||||||
|  | 		return true | ||||||
|  | 	// https://docs.anthropic.com/claude/reference/errors | ||||||
|  | 	case "authentication_error": | ||||||
|  | 		return true | ||||||
|  | 	case "permission_error": | ||||||
|  | 		return true | ||||||
|  | 	case "forbidden": | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic | ||||||
|  | 		return true | ||||||
|  | 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	return false | 	return false | ||||||
| @@ -101,6 +117,9 @@ func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.Err | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | 	if config.DebugEnabled { | ||||||
|  | 		logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) | ||||||
|  | 	} | ||||||
| 	err = resp.Body.Close() | 	err = resp.Body.Close() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| @@ -136,20 +155,20 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin | |||||||
| 	return fullRequestURL | 	return fullRequestURL | ||||||
| } | } | ||||||
|  |  | ||||||
| func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||||
| 	// quotaDelta is remaining quota to be consumed | 	// quotaDelta is remaining quota to be consumed | ||||||
| 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("error consuming token remain quota: " + err.Error()) | 		logger.SysError("error consuming token remain quota: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	err = model.CacheUpdateUserQuota(userId) | 	err = model.CacheUpdateUserQuota(ctx, userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("error update user quota cache: " + err.Error()) | 		logger.SysError("error update user quota cache: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	// totalQuota is total quota consumed | 	// totalQuota is total quota consumed | ||||||
| 	if totalQuota != 0 { | 	if totalQuota != 0 { | ||||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | 		model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) | ||||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) | 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | ||||||
| 	{ | 	{ | ||||||
| 		apiRouter.GET("/status", controller.GetStatus) | 		apiRouter.GET("/status", controller.GetStatus) | ||||||
|  | 		apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) | ||||||
| 		apiRouter.GET("/notice", controller.GetNotice) | 		apiRouter.GET("/notice", controller.GetNotice) | ||||||
| 		apiRouter.GET("/about", controller.GetAbout) | 		apiRouter.GET("/about", controller.GetAbout) | ||||||
| 		apiRouter.GET("/home_page_content", controller.GetHomePageContent) | 		apiRouter.GET("/home_page_content", controller.GetHomePageContent) | ||||||
| @@ -69,7 +70,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 			channelRoute.GET("/search", controller.SearchChannels) | 			channelRoute.GET("/search", controller.SearchChannels) | ||||||
| 			channelRoute.GET("/models", controller.ListModels) | 			channelRoute.GET("/models", controller.ListModels) | ||||||
| 			channelRoute.GET("/:id", controller.GetChannel) | 			channelRoute.GET("/:id", controller.GetChannel) | ||||||
| 			channelRoute.GET("/test", controller.TestAllChannels) | 			channelRoute.GET("/test", controller.TestChannels) | ||||||
| 			channelRoute.GET("/test/:id", controller.TestChannel) | 			channelRoute.GET("/test/:id", controller.TestChannel) | ||||||
| 			channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) | 			channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) | ||||||
| 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
|  |  | ||||||
| func SetDashboardRouter(router *gin.Engine) { | func SetDashboardRouter(router *gin.Engine) { | ||||||
| 	apiRouter := router.Group("/") | 	apiRouter := router.Group("/") | ||||||
|  | 	apiRouter.Use(middleware.CORS()) | ||||||
| 	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) | 	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) | ||||||
| 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | ||||||
| 	apiRouter.Use(middleware.TokenAuth()) | 	apiRouter.Use(middleware.TokenAuth()) | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ | |||||||
| 1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。 | 1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。 | ||||||
| 2. 把你的主题文件放到这个文件夹下。 | 2. 把你的主题文件放到这个文件夹下。 | ||||||
| 3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。 | 3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。 | ||||||
| 4. 修改 `common/constants.go` 中的 `ValidThemes`,把你的主题名称注册进去。 | 4. 修改 `common/config/config.go` 中的 `ValidThemes`,把你的主题名称注册进去。 | ||||||
| 5. 修改 `web/THEMES` 文件,这里也需要同步修改。 | 5. 修改 `web/THEMES` 文件,这里也需要同步修改。 | ||||||
|  |  | ||||||
| ## 主题列表 | ## 主题列表 | ||||||
| @@ -33,6 +33,12 @@ | |||||||
| ||| | ||| | ||||||
| ||| | ||| | ||||||
|  |  | ||||||
|  | ### 主题:air | ||||||
|  | 由 [Calon](https://github.com/Calcium-Ion) 开发。 | ||||||
|  | ||| | ||||||
|  | |:---:|:---:| | ||||||
|  |  | ||||||
|  |  | ||||||
| #### 开发说明 | #### 开发说明 | ||||||
|  |  | ||||||
| 请查看 [web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md) | 请查看 [web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md) | ||||||
|   | |||||||
| @@ -1,2 +1,3 @@ | |||||||
| default | default | ||||||
| berry | berry | ||||||
|  | air | ||||||
|   | |||||||
							
								
								
									
										26
									
								
								web/air/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								web/air/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. | ||||||
|  |  | ||||||
|  | # dependencies | ||||||
|  | /node_modules | ||||||
|  | /.pnp | ||||||
|  | .pnp.js | ||||||
|  |  | ||||||
|  | # testing | ||||||
|  | /coverage | ||||||
|  |  | ||||||
|  | # production | ||||||
|  | /build | ||||||
|  |  | ||||||
|  | # misc | ||||||
|  | .DS_Store | ||||||
|  | .env.local | ||||||
|  | .env.development.local | ||||||
|  | .env.test.local | ||||||
|  | .env.production.local | ||||||
|  |  | ||||||
|  | npm-debug.log* | ||||||
|  | yarn-debug.log* | ||||||
|  | yarn-error.log* | ||||||
|  | .idea | ||||||
|  | package-lock.json | ||||||
|  | yarn.lock | ||||||
							
								
								
									
										21
									
								
								web/air/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								web/air/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | # React Template | ||||||
|  |  | ||||||
|  | ## Basic Usages | ||||||
|  |  | ||||||
|  | ```shell | ||||||
|  | # Runs the app in the development mode | ||||||
|  | npm start | ||||||
|  |  | ||||||
|  | # Builds the app for production to the `build` folder | ||||||
|  | npm run build | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build, | ||||||
|  | for example: `REACT_APP_SERVER=http://your.domain.com`. | ||||||
|  |  | ||||||
|  | Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled. | ||||||
|  |  | ||||||
|  | ## Reference | ||||||
|  |  | ||||||
|  | 1. https://github.com/OIerDb-ng/OIerDb | ||||||
|  | 2. https://github.com/cornflourblue/react-hooks-redux-registration-login-example | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user