mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-26 11:23:43 +08:00 
			
		
		
		
	Compare commits
	
		
			141 Commits
		
	
	
		
			v0.4.7-alp
			...
			v0.5.3-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 5a62357c93 | ||
|  | b464e2907a | ||
|  | d96cf2e84d | ||
|  | 446337c329 | ||
|  | 1dfa190e79 | ||
|  | 2d49ca6a07 | ||
|  | 89bcaaf989 | ||
|  | afcd1bd27b | ||
|  | c2c455c980 | ||
|  | 30a7f1a1c7 | ||
|  | c9d2e42a9e | ||
|  | 3fca6ff534 | ||
|  | 8cbbeb784f | ||
|  | ec88c0c240 | ||
|  | 065147b440 | ||
|  | fe8f216dd9 | ||
|  | b7d0616ae0 | ||
|  | ce9c8024a6 | ||
|  | 8a866078b2 | ||
|  | 3e81d8af45 | ||
|  | b8cb86c2c1 | ||
|  | f45d586400 | ||
|  | 50dec03ff3 | ||
|  | f31d400b6f | ||
|  | 130e6bfd83 | ||
|  | d1335ebc01 | ||
|  | e92da7928b | ||
|  | d1b6f492b6 | ||
|  | b9f6461dd4 | ||
|  | 0a39521a3d | ||
|  | c134604cee | ||
|  | 929e43ef81 | ||
|  | dce8bbe1ca | ||
|  | bc2f48b1f2 | ||
|  | 889af8b2db | ||
|  | 4eea096654 | ||
|  | 4ab3211c0e | ||
|  | 3da119efba | ||
|  | dccd66b852 | ||
|  | 2fcd6852e0 | ||
|  | 9b4d1964d4 | ||
|  | 806bf8241c | ||
|  | ce93c9b6b2 | ||
|  | 4ec4289565 | ||
|  | 3dc5a0f91d | ||
|  | 80a846673a | ||
|  | 26c6719ea3 | ||
|  | c87e05bfc2 | ||
|  | e6938bd236 | ||
|  | 8f721d67a5 | ||
|  | fcc1e2d568 | ||
|  | 9a1db61675 | ||
|  | 3c940113ab | ||
|  | 0495b9a0d7 | ||
|  | 12a0e7105e | ||
|  | e628b643cd | ||
|  | 675847bf98 | ||
|  | 2ff15baf66 | ||
|  | 4139a7036f | ||
|  | 02da0b51f8 | ||
|  | 35cfebee12 | ||
|  | 0e088f7c3e | ||
|  | f61d326721 | ||
|  | 74b06b643a | ||
|  | ccf7709e23 | ||
|  | d592e2c8b8 | ||
|  | b520b54625 | ||
|  | 81c5901123 | ||
|  | abc53cb208 | ||
|  | 2b17bb8dd7 | ||
|  | ea73201b6f | ||
|  | 6215d2e71c | ||
|  | d17bdc40a7 | ||
|  | 280df27705 | ||
|  | 991f5bf4ee | ||
|  | 701aaba191 | ||
|  | 3bab5b48bf | ||
|  | f3bccee3b5 | ||
|  | d84b0b0f5d | ||
|  | d383302e8a | ||
|  | 04f40def2f | ||
|  | c48b7bc0f5 | ||
|  | b09daf5ec1 | ||
|  | c90c0ecef4 | ||
|  | 1ab5fb7d2d | ||
|  | f769711c19 | ||
|  | edc5156693 | ||
|  | 9ec6506c32 | ||
|  | f387cc5ead | ||
|  | 569b68c43b | ||
|  | f0c40a6cd0 | ||
|  | 0cea9e6a6f | ||
|  | b1b3651e84 | ||
|  | 8f6bd51f58 | ||
|  | bddbf57104 | ||
|  | 9a16b0f9e5 | ||
|  | 3530309a31 | ||
|  | 733ebc067b | ||
|  | 6a8567ac14 | ||
|  | aabc546691 | ||
|  | 1c82b06f35 | ||
|  | 9e4109672a | ||
|  | 64c35334e6 | ||
|  | 0ce572b405 | ||
|  | a326ac4b28 | ||
|  | 05b0e77839 | ||
|  | 51f19470bc | ||
|  | 737672fb0b | ||
|  | 0941e294bf | ||
|  | 431d505f79 | ||
|  | f0dc7f3f06 | ||
|  | 99fed1f850 | ||
|  | 4dc5388a80 | ||
|  | f81f4c60b2 | ||
|  | c613d8b6b2 | ||
|  | 7adac1c09c | ||
|  | 6f05128368 | ||
|  | 9b178a28a3 | ||
|  | 4a6a7f4635 | ||
|  | 6b1a24d650 | ||
|  | 94ba3dd024 | ||
|  | f6eb4e5628 | ||
|  | 57bd907f83 | ||
|  | dd8e8d5ee8 | ||
|  | 1ca1aa0cdc | ||
|  | f2ba0c0300 | ||
|  | f5c1fcd3c3 | ||
|  | 5fdf670a19 | ||
|  | 3ce982d8ee | ||
|  | a515f9284e | ||
|  | cccf5e4a07 | ||
|  | b0bfb9c9a1 | ||
|  | 3aff61a973 | ||
|  | 0fd1ff4d9e | ||
|  | e2777bf73e | ||
|  | 77a16e6415 | ||
|  | 827942c8a9 | ||
|  | 604ff20541 | ||
|  | 25017219f5 | ||
|  | 2dd4ad0e06 | ||
|  | 61dc117da7 | 
							
								
								
									
										4
									
								
								.github/ISSUE_TEMPLATE/bug_report.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/ISSUE_TEMPLATE/bug_report.md
									
									
									
									
										vendored
									
									
								
							| @@ -8,11 +8,13 @@ assignees: '' | ||||
| --- | ||||
|  | ||||
| **例行检查** | ||||
|  | ||||
| [//]: # (方框内删除已有的空格,填 x 号) | ||||
| + [ ] 我已确认目前没有类似 issue | ||||
| + [ ] 我已确认我已升级到最新版本 | ||||
| + [ ] 我已完整查看过项目 README,尤其是常见问题部分 | ||||
| + [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈  | ||||
| + [ ] 我理解并认可上述内容,并理解项目维护者精力有限,不遵循规则的 issue 可能会被无视或直接关闭 | ||||
| + [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** | ||||
|  | ||||
| **问题描述** | ||||
|  | ||||
|   | ||||
							
								
								
									
										3
									
								
								.github/ISSUE_TEMPLATE/config.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/ISSUE_TEMPLATE/config.yml
									
									
									
									
										vendored
									
									
								
							| @@ -6,6 +6,3 @@ contact_links: | ||||
|   - name: 赞赏支持 | ||||
|     url: https://iamazing.cn/page/reward | ||||
|     about: 请作者喝杯咖啡,以激励作者持续开发 | ||||
|   - name: 付费部署或定制功能 | ||||
|     url: https://openai.justsong.cn/ | ||||
|     about: 加群后联系群主 | ||||
|   | ||||
							
								
								
									
										5
									
								
								.github/ISSUE_TEMPLATE/feature_request.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/ISSUE_TEMPLATE/feature_request.md
									
									
									
									
										vendored
									
									
								
							| @@ -8,10 +8,13 @@ assignees: '' | ||||
| --- | ||||
|  | ||||
| **例行检查** | ||||
|  | ||||
| [//]: # (方框内删除已有的空格,填 x 号) | ||||
| + [ ] 我已确认目前没有类似 issue | ||||
| + [ ] 我已确认我已升级到最新版本 | ||||
| + [ ] 我已完整查看过项目 README,已确定现有版本无法满足需求 | ||||
| + [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 | ||||
| + [ ] 我理解并认可上述内容,并理解项目维护者精力有限,不遵循规则的 issue 可能会被无视或直接关闭 | ||||
| + [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** | ||||
|  | ||||
| **功能描述** | ||||
|  | ||||
|   | ||||
							
								
								
									
										10
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,4 +1,4 @@ | ||||
| name: Publish Docker image (amd64) | ||||
| name: Publish Docker image (amd64, English) | ||||
|  | ||||
| on: | ||||
|   push: | ||||
| @@ -33,20 +33,12 @@ jobs: | ||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} | ||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} | ||||
|  | ||||
|       - name: Log in to the Container registry | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.actor }} | ||||
|           password: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
|       - name: Extract metadata (tags, labels) for Docker | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api-en | ||||
|             ghcr.io/one-api-en | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|   | ||||
							
								
								
									
										297
									
								
								README.en.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										297
									
								
								README.en.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,297 @@ | ||||
| <p align="right"> | ||||
|     <a href="./README.md">中文</a> | <strong>English</strong> | ||||
| </p> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
| </p> | ||||
|  | ||||
| <div align="center"> | ||||
|  | ||||
| # One API | ||||
|  | ||||
| _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> | ||||
|     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> | ||||
|     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> | ||||
|     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> | ||||
|   </a> | ||||
| </p> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="#deployment">Deployment Tutorial</a> | ||||
|   · | ||||
|   <a href="#usage">Usage</a> | ||||
|   · | ||||
|   <a href="https://github.com/songquanpeng/one-api/issues">Feedback</a> | ||||
|   · | ||||
|   <a href="#screenshots">Screenshots</a> | ||||
|   · | ||||
|   <a href="https://openai.justsong.cn/">Live Demo</a> | ||||
|   · | ||||
|   <a href="#faq">FAQ</a> | ||||
|   · | ||||
|   <a href="#related-projects">Related Projects</a> | ||||
|   · | ||||
|   <a href="https://iamazing.cn/page/reward">Donate</a> | ||||
| </p> | ||||
|  | ||||
| > **Warning**: This README is translated by ChatGPT. Please feel free to submit a PR if you find any translation errors. | ||||
|  | ||||
| > **Warning**: The Docker image for English version is `justsong/one-api-en`. | ||||
|  | ||||
| > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | ||||
|  | ||||
| ## Features | ||||
| 1. Support for multiple large models: | ||||
|    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) | ||||
|    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||
| 2. Supports access to multiple channels through **load balancing**. | ||||
| 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | ||||
| 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | ||||
| 5. Supports **token management** that allows setting token expiration time and usage count. | ||||
| 6. Supports **voucher management** that enables batch generation and export of vouchers. Vouchers can be used for account balance replenishment. | ||||
| 7. Supports **channel management** that allows bulk creation of channels. | ||||
| 8. Supports **user grouping** and **channel grouping** for setting different rates for different groups. | ||||
| 9. Supports channel **model list configuration**. | ||||
| 10. Supports **quota details checking**. | ||||
| 11. Supports **user invite rewards**. | ||||
| 12. Allows display of balance in USD. | ||||
| 13. Supports announcement publishing, recharge link setting, and initial balance setting for new users. | ||||
| 14. Offers rich **customization** options: | ||||
|     1. Supports customization of system name, logo, and footer. | ||||
|     2. Supports customization of homepage and about page using HTML & Markdown code, or embedding a standalone webpage through iframe. | ||||
| 15. Supports management API access through system access tokens. | ||||
| 16. Supports Cloudflare Turnstile user verification. | ||||
| 17. Supports user management and multiple user login/registration methods: | ||||
|     + Email login/registration and password reset via email. | ||||
|     + [GitHub OAuth](https://github.com/settings/applications/new). | ||||
|     + WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)). | ||||
| 18. Immediate support and encapsulation of other major model APIs as they become available. | ||||
|  | ||||
| ## Deployment | ||||
| ### Docker Deployment | ||||
| Deployment command: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en` | ||||
|  | ||||
| Update command: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | ||||
|  | ||||
| The first `3000` in `-p 3000:3000` is the port of the host, which can be modified as needed. | ||||
|  | ||||
| Data will be saved in the `/home/ubuntu/data/one-api` directory on the host. Ensure that the directory exists and has write permissions, or change it to a suitable directory. | ||||
|  | ||||
| Nginx reference configuration: | ||||
| ``` | ||||
| server{ | ||||
|    server_name openai.justsong.cn;  # Modify your domain name accordingly | ||||
|     | ||||
|    location / { | ||||
|           client_max_body_size  64m; | ||||
|           proxy_http_version 1.1; | ||||
|           proxy_pass http://localhost:3000;  # Modify your port accordingly | ||||
|           proxy_set_header Host $host; | ||||
|           proxy_set_header X-Forwarded-For $remote_addr; | ||||
|           proxy_cache_bypass $http_upgrade; | ||||
|           proxy_set_header Accept-Encoding gzip; | ||||
|    } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| Next, configure HTTPS with Let's Encrypt certbot: | ||||
| ```bash | ||||
| # Install certbot on Ubuntu: | ||||
| sudo snap install --classic certbot | ||||
| sudo ln -s /snap/bin/certbot /usr/bin/certbot | ||||
| # Generate certificates & modify Nginx configuration | ||||
| sudo certbot --nginx | ||||
| # Follow the prompts | ||||
| # Restart Nginx | ||||
| sudo service nginx restart | ||||
| ``` | ||||
|  | ||||
| The initial account username is `root` and password is `123456`. | ||||
|  | ||||
| ### Manual Deployment | ||||
| 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: | ||||
|    ```shell | ||||
|    git clone https://github.com/songquanpeng/one-api.git | ||||
|     | ||||
|    # Build the frontend | ||||
|    cd one-api/web | ||||
|    npm install | ||||
|    npm run build | ||||
|     | ||||
|    # Build the backend | ||||
|    cd .. | ||||
|    go mod download | ||||
|    go build -ldflags "-s -w" -o one-api | ||||
|    ``` | ||||
| 2. Run: | ||||
|    ```shell | ||||
|    chmod u+x one-api | ||||
|    ./one-api --port 3000 --log-dir ./logs | ||||
|    ``` | ||||
| 3. Access [http://localhost:3000/](http://localhost:3000/) and log in. The initial account username is `root` and password is `123456`. | ||||
|  | ||||
| For more detailed deployment tutorials, please refer to [this page](https://iamazing.cn/page/how-to-deploy-a-website). | ||||
|  | ||||
| ### Multi-machine Deployment | ||||
| 1. Set the same `SESSION_SECRET` for all servers. | ||||
| 2. Set `SQL_DSN` and use MySQL instead of SQLite. All servers should connect to the same database. | ||||
| 3. Set the `NODE_TYPE` for all non-master nodes to `slave`. | ||||
| 4. Set `SYNC_FREQUENCY` for servers to periodically sync configurations from the database. | ||||
| 5. Non-master nodes can optionally set `FRONTEND_BASE_URL` to redirect page requests to the master server. | ||||
| 6. Install Redis separately on non-master nodes, and configure `REDIS_CONN_STRING` so that the database can be accessed with zero latency when the cache has not expired. | ||||
| 7. If the main server also has high latency accessing the database, Redis must be enabled and `SYNC_FREQUENCY` must be set to periodically sync configurations from the database. | ||||
|  | ||||
| Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | ||||
|  | ||||
| ### Deployment on Control Panels (e.g., Baota) | ||||
| Refer to [#175](https://github.com/songquanpeng/one-api/issues/175) for detailed instructions. | ||||
|  | ||||
| If you encounter a blank page after deployment, refer to [#97](https://github.com/songquanpeng/one-api/issues/97) for possible solutions. | ||||
|  | ||||
| ### Deployment on Third-Party Platforms | ||||
| <details> | ||||
| <summary><strong>Deploy on Sealos</strong></summary> | ||||
| <div> | ||||
|  | ||||
| > Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users. | ||||
|  | ||||
| > Click the button below to deploy with one click.👇 | ||||
|  | ||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||
|  | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><strong>Deployment on Zeabur</strong></summary> | ||||
| <div> | ||||
|  | ||||
| > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. | ||||
|  | ||||
| 1. First, fork the code. | ||||
| 2. Go to [Zeabur](https://zeabur.com/), log in, and enter the console. | ||||
| 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | ||||
| 4. Copy the connection parameters and run ```create database `one-api` ``` to create the database. | ||||
| 5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. | ||||
| 6. Automatic deployment will start, but please cancel it for now. Go to the Variable tab, add a `PORT` with a value of `3000`, and then add a `SQL_DSN` with a value of `<username>:<password>@tcp(<addr>:<port>)/one-api`. Save the changes. Please note that if `SQL_DSN` is not set, data will not be persisted, and the data will be lost after redeployment. | ||||
| 7. Select Redeploy. | ||||
| 8. In the Domains tab, select a suitable domain name prefix, such as "my-one-api". The final domain name will be "my-one-api.zeabur.app". You can also CNAME your own domain name. | ||||
| 9. Wait for the deployment to complete, and click on the generated domain name to access One API. | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
|  | ||||
| ## Configuration | ||||
| The system is ready to use out of the box. | ||||
|  | ||||
| You can configure it by setting environment variables or command line parameters. | ||||
|  | ||||
| After the system starts, log in as the `root` user to further configure the system. | ||||
|  | ||||
| ## Usage | ||||
| Add your API Key on the `Channels` page, and then add an access token on the `Tokens` page. | ||||
|  | ||||
| You can then use your access token to access One API. The usage is consistent with the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). | ||||
|  | ||||
| In places where the OpenAI API is used, remember to set the API Base to your One API deployment address, for example: `https://openai.justsong.cn`. The API Key should be the token generated in One API. | ||||
|  | ||||
| Note that the specific API Base format depends on the client you are using. | ||||
|  | ||||
| ```mermaid | ||||
| graph LR | ||||
|     A(User) | ||||
|     A --->|Request| B(One API) | ||||
|     B -->|Relay Request| C(OpenAI) | ||||
|     B -->|Relay Request| D(Azure) | ||||
|     B -->|Relay Request| E(Other downstream channels) | ||||
| ``` | ||||
|  | ||||
| To specify which channel to use for the current request, you can add the channel ID after the token, for example: `Authorization: Bearer ONE_API_KEY-CHANNEL_ID`. | ||||
| Note that the token needs to be created by an administrator to specify the channel ID. | ||||
|  | ||||
| If the channel ID is not provided, load balancing will be used to distribute the requests to multiple channels. | ||||
|  | ||||
| ### Environment Variables | ||||
| 1. `REDIS_CONN_STRING`: When set, Redis will be used as the storage for request rate limiting instead of memory. | ||||
|     + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
| 2. `SESSION_SECRET`: When set, a fixed session key will be used to ensure that cookies of logged-in users are still valid after the system restarts. | ||||
|     + Example: `SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | ||||
|     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
| 4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
|     + Example: `SYNC_FREQUENCY=60` | ||||
| 6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
|     + Example: `NODE_TYPE=slave` | ||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||
| 9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
|     + Example: `POLLING_INTERVAL=5` | ||||
|  | ||||
| ### Command Line Parameters | ||||
| 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. | ||||
|     + Example: `--port 3000` | ||||
| 2. `--log-dir <log_dir>`: Specifies the log directory. If not set, the logs will not be saved. | ||||
|     + Example: `--log-dir ./logs` | ||||
| 3. `--version`: Prints the system version number and exits. | ||||
| 4. `--help`: Displays the command usage help and parameter descriptions. | ||||
|  | ||||
| ## Screenshots | ||||
|  | ||||
|  | ||||
|  | ||||
| ## FAQ | ||||
| 1. What is quota? How is it calculated? Does One API have quota calculation issues? | ||||
|     + Quota = Group multiplier * Model multiplier * (number of prompt tokens + number of completion tokens * completion multiplier) | ||||
|     + The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition. | ||||
|     + If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different. | ||||
| 2. Why does it prompt "insufficient quota" even though my account balance is sufficient? | ||||
|     + Please check if your token quota is sufficient. It is separate from the account balance. | ||||
|     + The token quota is used to set the maximum usage and can be freely set by the user. | ||||
| 3. It says "No available channels" when trying to use a channel. What should I do? | ||||
|     + Please check the user and channel group settings. | ||||
|     + Also check the channel model settings. | ||||
| 4. Channel testing reports an error: "invalid character '<' looking for beginning of value" | ||||
|     + This error occurs when the returned value is not valid JSON but an HTML page. | ||||
|     + Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare. | ||||
| 5. ChatGPT Next Web reports an error: "Failed to fetch" | ||||
|     + Do not set `BASE_URL` during deployment. | ||||
|     + Double-check that your interface address and API Key are correct. | ||||
|  | ||||
| ## Related Projects | ||||
| [FastGPT](https://github.com/c121914yu/FastGPT): Build an AI knowledge base in three minutes | ||||
|  | ||||
| ## Note | ||||
| This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | ||||
|  | ||||
| This project is released under the MIT license. Based on this, attribution and a link to this project must be included at the bottom of the page. | ||||
|  | ||||
| The same applies to derivative projects based on this project. | ||||
|  | ||||
| If you do not wish to include attribution, prior authorization must be obtained. | ||||
|  | ||||
| According to the MIT license, users should bear the risk and responsibility of using this project, and the developer of this open-source project is not responsible for this. | ||||
							
								
								
									
										105
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										105
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,3 +1,8 @@ | ||||
| <p align="right"> | ||||
|    <strong>中文</strong> | <a href="./README.en.md">English</a> | ||||
| </p> | ||||
|  | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
| </p> | ||||
| @@ -6,7 +11,7 @@ | ||||
|  | ||||
| # One API | ||||
|  | ||||
| _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用✨_ | ||||
| _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ | ||||
|  | ||||
| </div> | ||||
|  | ||||
| @@ -46,50 +51,61 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 | ||||
|   <a href="https://iamazing.cn/page/reward">赞赏支持</a> | ||||
| </p> | ||||
|  | ||||
| > **Note**:本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||
|  | ||||
| > **Note**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | ||||
|  | ||||
| > **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。 | ||||
|  | ||||
| ## 功能 | ||||
| 1. 支持多种 API 访问渠道,欢迎 PR 或提 issue 添加更多渠道: | ||||
|    + [x] OpenAI 官方通道(支持配置代理) | ||||
|    + [x] **Azure OpenAI API** | ||||
| 1. 支持多种大模型: | ||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) | ||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||
| 2. 支持配置镜像以及众多第三方代理服务: | ||||
|    + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|    + [x] [API2D](https://api2d.com/r/197971) | ||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||
|    + [x] [API2GPT](http://console.api2gpt.com/m/00002S) | ||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||
|    + [x] [AI.LS](https://ai.ls) | ||||
|    + [x] [OpenAI Max](https://openaimax.com) | ||||
|    + [x] 自定义渠道:例如各种未收录的第三方代理服务 | ||||
| 2. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| 4. 支持**多机部署**,[详见此处](#多机部署)。 | ||||
| 5. 支持**令牌管理**,设置令牌的过期时间和使用次数。 | ||||
| 6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||
| 7. 支持**通道管理**,批量创建通道。 | ||||
| 8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||
| 9. 支持渠道**设置模型列表**。 | ||||
| 10. 支持**查看额度明细**。 | ||||
| 11. 支持**用户邀请奖励**。 | ||||
| 12. 支持以美元为单位显示额度。 | ||||
| 13. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||
| 14. 支持丰富的**自定义**设置, | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||
| 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||
| 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||
| 8. 支持**通道管理**,批量创建通道。 | ||||
| 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||
| 10. 支持渠道**设置模型列表**。 | ||||
| 11. 支持**查看额度明细**。 | ||||
| 12. 支持**用户邀请奖励**。 | ||||
| 13. 支持以美元为单位显示额度。 | ||||
| 14. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||
| 15. 支持模型映射,重定向用户的请求模型。 | ||||
| 16. 支持失败自动重试。 | ||||
| 17. 支持绘图接口。 | ||||
| 18. 支持丰富的**自定义**设置, | ||||
|     1. 支持自定义系统名称,logo 以及页脚。 | ||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||
| 15. 支持通过系统访问令牌访问管理 API。 | ||||
| 16. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 17. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册以及通过邮箱进行密码重置。 | ||||
| 19. 支持通过系统访问令牌访问管理 API。 | ||||
| 20. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 21. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 18. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 | ||||
|  | ||||
| ## 部署 | ||||
| ### 基于 Docker 进行部署 | ||||
| 部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` | ||||
|  | ||||
| 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 | ||||
|  | ||||
| 如果你的并发量较大,推荐设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 | ||||
|  | ||||
| 更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | ||||
|  | ||||
| `-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 | ||||
| @@ -109,6 +125,7 @@ server{ | ||||
|           proxy_set_header X-Forwarded-For $remote_addr; | ||||
|           proxy_cache_bypass $http_upgrade; | ||||
|           proxy_set_header Accept-Encoding gzip; | ||||
|           proxy_read_timeout 300s;  # GPT-4 需要较长的超时时间,请自行调整 | ||||
|    } | ||||
| } | ||||
| ``` | ||||
| @@ -136,7 +153,7 @@ sudo service nginx restart | ||||
|    cd one-api/web | ||||
|    npm install | ||||
|    npm run build | ||||
|  | ||||
|     | ||||
|    # 构建后端 | ||||
|    cd .. | ||||
|    go mod download | ||||
| @@ -154,8 +171,8 @@ sudo service nginx restart | ||||
| ### 多机部署 | ||||
| 1. 所有服务器 `SESSION_SECRET` 设置一样的值。 | ||||
| 2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,所有服务器连接同一个数据库。 | ||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`。 | ||||
| 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置。 | ||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | ||||
| 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 | ||||
| 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 | ||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 | ||||
| 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 | ||||
| @@ -178,7 +195,7 @@ sudo service nginx restart | ||||
| docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web | ||||
| ``` | ||||
|  | ||||
| 注意修改端口号和 `BASE_URL`。 | ||||
| 注意修改端口号,之后在页面上设置接口地址(例如:https://openai.justsong.cn/ )和 API Key 即可。 | ||||
|  | ||||
| #### ChatGPT Web | ||||
| 项目主页:https://github.com/Chanzhaoyu/chatgpt-web | ||||
| @@ -190,6 +207,19 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
| 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | ||||
|  | ||||
| ### 部署到第三方平台 | ||||
| <details> | ||||
| <summary><strong>部署到 Sealos </strong></summary> | ||||
| <div> | ||||
|  | ||||
| > Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩。 | ||||
|  | ||||
| 点击以下按钮一键部署: | ||||
|  | ||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><strong>部署到 Zeabur</strong></summary> | ||||
| <div> | ||||
| @@ -216,6 +246,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
|  | ||||
| 等到系统启动后,使用 `root` 用户登录系统并做进一步的配置。 | ||||
|  | ||||
| **Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。 | ||||
|  | ||||
| ## 使用方法 | ||||
| 在`渠道`页面中添加你的 API Key,之后在`令牌`页面中新增访问令牌。 | ||||
|  | ||||
| @@ -246,7 +278,10 @@ graph LR | ||||
|    + 例子:`SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 8.0 版本。 | ||||
|    + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
| 4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址。 | ||||
|    + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 | ||||
|    + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 | ||||
|    + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 | ||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。 | ||||
|    + 例子:`SYNC_FREQUENCY=60` | ||||
| @@ -256,7 +291,7 @@ graph LR | ||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 9. `REQUEST_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
| 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|    + 例子:`POLLING_INTERVAL=5` | ||||
|  | ||||
| ### 命令行参数 | ||||
| @@ -281,6 +316,7 @@ https://openai.justsong.cn | ||||
|    + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) | ||||
|    + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 | ||||
|    + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 | ||||
|    + 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 | ||||
| 2. 账户额度足够为什么提示额度不足? | ||||
|    + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 | ||||
|    + 令牌额度仅供用户设置最大使用量,用户可自由设置。 | ||||
| @@ -293,13 +329,16 @@ https://openai.justsong.cn | ||||
| 5. ChatGPT Next Web 报错:`Failed to fetch` | ||||
|    + 部署的时候不要设置 `BASE_URL`。 | ||||
|    + 检查你的接口地址和 API Key 有没有填对。 | ||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||
|    + 上游通道 429 了。 | ||||
|  | ||||
| ## 相关项目 | ||||
| [FastGPT](https://github.com/c121914yu/FastGPT): 三分钟搭建 AI 知识库 | ||||
|  | ||||
| ## 注意 | ||||
| 本项目为开源项目,请在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||
|  | ||||
| 本项目使用 MIT 协议进行开源,请以某种方式保留 One API 的版权信息。 | ||||
| 本项目使用 MIT 协议进行开源,**在此基础上**,必须在页面底部保留署名以及指向本项目的链接。如果不想保留署名,必须首先获得授权。 | ||||
|  | ||||
| 同样适用于基于本项目的二开项目。 | ||||
|  | ||||
| 依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 | ||||
| @@ -1,25 +1,29 @@ | ||||
| #!/bin/bash | ||||
|  | ||||
| if [ $# -ne 3 ]; then | ||||
|   echo "Usage: time_test.sh <domain> <key> <count>" | ||||
| if [ $# -lt 3 ]; then | ||||
|   echo "Usage: time_test.sh <domain> <key> <count> [<model>]" | ||||
|   exit 1 | ||||
| fi | ||||
|  | ||||
| domain=$1 | ||||
| key=$2 | ||||
| count=$3 | ||||
| model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo | ||||
|  | ||||
| total_time=0 | ||||
| times=() | ||||
|  | ||||
| for ((i=1; i<=count; i++)); do | ||||
|   result=$(curl -o /dev/null -s -w %{time_total}\\n \ | ||||
|   result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \ | ||||
|            https://"$domain"/v1/chat/completions \ | ||||
|            -H "Content-Type: application/json" \ | ||||
|            -H "Authorization: Bearer $key" \ | ||||
|            -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "gpt-3.5-turbo", "stream": false, "max_tokens": 1}') | ||||
|   echo "$result" | ||||
|   total_time=$(bc <<< "$total_time + $result") | ||||
|   times+=("$result") | ||||
|            -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}') | ||||
|   http_code=$(echo "$result" | awk '{print $1}') | ||||
|   time=$(echo "$result" | awk '{print $2}') | ||||
|   echo "HTTP status code: $http_code, Time taken: $time" | ||||
|   total_time=$(bc <<< "$total_time + $time") | ||||
|   times+=("$time") | ||||
| done | ||||
|  | ||||
| average_time=$(echo "scale=4; $total_time / $count" | bc) | ||||
|   | ||||
| @@ -42,6 +42,19 @@ var WeChatAuthEnabled = false | ||||
| var TurnstileCheckEnabled = false | ||||
| var RegisterEnabled = true | ||||
|  | ||||
| var EmailDomainRestrictionEnabled = false | ||||
| var EmailDomainWhitelist = []string{ | ||||
| 	"gmail.com", | ||||
| 	"163.com", | ||||
| 	"126.com", | ||||
| 	"qq.com", | ||||
| 	"outlook.com", | ||||
| 	"hotmail.com", | ||||
| 	"icloud.com", | ||||
| 	"yahoo.com", | ||||
| 	"foxmail.com", | ||||
| } | ||||
|  | ||||
| var LogConsumeEnabled = true | ||||
|  | ||||
| var SMTPServer = "" | ||||
| @@ -67,14 +80,18 @@ var ChannelDisableThreshold = 5.0 | ||||
| var AutomaticDisableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| var RetryTimes = 0 | ||||
|  | ||||
| var RootUserEmail = "" | ||||
|  | ||||
| var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||
|  | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("REQUEST_INTERVAL")) | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | ||||
|  | ||||
| var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY | ||||
|  | ||||
| const ( | ||||
| 	RoleGuestUser  = 0 | ||||
| 	RoleCommonUser = 1 | ||||
| @@ -148,20 +165,32 @@ const ( | ||||
| 	ChannelTypeAIProxy   = 10 | ||||
| 	ChannelTypePaLM      = 11 | ||||
| 	ChannelTypeAPI2GPT   = 12 | ||||
| 	ChannelTypeAIGC2D    = 13 | ||||
| 	ChannelTypeAnthropic = 14 | ||||
| 	ChannelTypeBaidu     = 15 | ||||
| 	ChannelTypeZhipu     = 16 | ||||
| 	ChannelTypeAli       = 17 | ||||
| 	ChannelTypeXunfei    = 18 | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| 	"",                             // 0 | ||||
| 	"https://api.openai.com",       // 1 | ||||
| 	"https://oa.api2d.net",         // 2 | ||||
| 	"",                             // 3 | ||||
| 	"https://api.openai-proxy.org", // 4 | ||||
| 	"https://api.openai-sb.com",    // 5 | ||||
| 	"https://api.openaimax.com",    // 6 | ||||
| 	"https://api.ohmygpt.com",      // 7 | ||||
| 	"",                             // 8 | ||||
| 	"https://api.caipacity.com",    // 9 | ||||
| 	"https://api.aiproxy.io",       // 10 | ||||
| 	"",                             // 11 | ||||
| 	"https://api.api2gpt.com",      // 12 | ||||
| 	"",                               // 0 | ||||
| 	"https://api.openai.com",         // 1 | ||||
| 	"https://oa.api2d.net",           // 2 | ||||
| 	"",                               // 3 | ||||
| 	"https://api.closeai-proxy.xyz",  // 4 | ||||
| 	"https://api.openai-sb.com",      // 5 | ||||
| 	"https://api.openaimax.com",      // 6 | ||||
| 	"https://api.ohmygpt.com",        // 7 | ||||
| 	"",                               // 8 | ||||
| 	"https://api.caipacity.com",      // 9 | ||||
| 	"https://api.aiproxy.io",         // 10 | ||||
| 	"",                               // 11 | ||||
| 	"https://api.api2gpt.com",        // 12 | ||||
| 	"https://api.aigc2d.com",         // 13 | ||||
| 	"https://api.anthropic.com",      // 14 | ||||
| 	"https://aip.baidubce.com",       // 15 | ||||
| 	"https://open.bigmodel.cn",       // 16 | ||||
| 	"https://dashscope.aliyuncs.com", // 17 | ||||
| 	"",                               // 18 | ||||
| } | ||||
|   | ||||
| @@ -4,9 +4,11 @@ import "encoding/json" | ||||
|  | ||||
| // ModelRatio | ||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||
| // https://openai.com/pricing | ||||
| // TODO: when a new api is enabled, check the pricing here | ||||
| // 1 === $0.002 / 1K tokens | ||||
| // 1 === ¥0.014 / 1k tokens | ||||
| var ModelRatio = map[string]float64{ | ||||
| 	"gpt-4":                   15, | ||||
| 	"gpt-4-0314":              15, | ||||
| @@ -31,10 +33,23 @@ var ModelRatio = map[string]float64{ | ||||
| 	"curie":                   10, | ||||
| 	"babbage":                 10, | ||||
| 	"ada":                     10, | ||||
| 	"text-embedding-ada-002":  0.2, | ||||
| 	"text-embedding-ada-002":  0.05, | ||||
| 	"text-search-ada-doc-001": 10, | ||||
| 	"text-moderation-stable":  0.1, | ||||
| 	"text-moderation-latest":  0.1, | ||||
| 	"dall-e":                  8, | ||||
| 	"claude-instant-1":        0.75, | ||||
| 	"claude-2":                30, | ||||
| 	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens | ||||
| 	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens | ||||
| 	"PaLM-2":                  1, | ||||
| 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens | ||||
| 	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag | ||||
| 	"qwen-plus-v1":            0.5715, // Same as above | ||||
| 	"SparkDesk":               0.8572, // TBD | ||||
| } | ||||
|  | ||||
| func ModelRatio2JSONString() string { | ||||
|   | ||||
| @@ -7,16 +7,24 @@ import ( | ||||
| ) | ||||
|  | ||||
| func GetSubscription(c *gin.Context) { | ||||
| 	var quota int | ||||
| 	var remainQuota int | ||||
| 	var usedQuota int | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	var expiredTime int64 | ||||
| 	if common.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		quota = token.RemainQuota | ||||
| 		expiredTime = token.ExpiredTime | ||||
| 		remainQuota = token.RemainQuota | ||||
| 		usedQuota = token.UsedQuota | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		quota, err = model.GetUserQuota(userId) | ||||
| 		remainQuota, err = model.GetUserQuota(userId) | ||||
| 		usedQuota, err = model.GetUserUsedQuota(userId) | ||||
| 	} | ||||
| 	if expiredTime <= 0 { | ||||
| 		expiredTime = 0 | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		openAIError := OpenAIError{ | ||||
| @@ -28,16 +36,21 @@ func GetSubscription(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	quota := remainQuota + usedQuota | ||||
| 	amount := float64(quota) | ||||
| 	if common.DisplayInCurrencyEnabled { | ||||
| 		amount /= common.QuotaPerUnit | ||||
| 	} | ||||
| 	if token != nil && token.UnlimitedQuota { | ||||
| 		amount = 100000000 | ||||
| 	} | ||||
| 	subscription := OpenAISubscriptionResponse{ | ||||
| 		Object:             "billing_subscription", | ||||
| 		HasPaymentMethod:   true, | ||||
| 		SoftLimitUSD:       amount, | ||||
| 		HardLimitUSD:       amount, | ||||
| 		SystemHardLimitUSD: amount, | ||||
| 		AccessUntil:        expiredTime, | ||||
| 	} | ||||
| 	c.JSON(200, subscription) | ||||
| 	return | ||||
| @@ -71,7 +84,7 @@ func GetUsage(c *gin.Context) { | ||||
| 	} | ||||
| 	usage := OpenAIUsageResponse{ | ||||
| 		Object:     "list", | ||||
| 		TotalUsage: amount, | ||||
| 		TotalUsage: amount * 100, | ||||
| 	} | ||||
| 	c.JSON(200, usage) | ||||
| 	return | ||||
|   | ||||
| @@ -22,6 +22,7 @@ type OpenAISubscriptionResponse struct { | ||||
| 	SoftLimitUSD       float64 `json:"soft_limit_usd"` | ||||
| 	HardLimitUSD       float64 `json:"hard_limit_usd"` | ||||
| 	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` | ||||
| 	AccessUntil        int64   `json:"access_until"` | ||||
| } | ||||
|  | ||||
| type OpenAIUsageDailyCost struct { | ||||
| @@ -32,6 +33,13 @@ type OpenAIUsageDailyCost struct { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type OpenAICreditGrants struct { | ||||
| 	Object         string  `json:"object"` | ||||
| 	TotalGranted   float64 `json:"total_granted"` | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| 	TotalAvailable float64 `json:"total_available"` | ||||
| } | ||||
|  | ||||
| type OpenAIUsageResponse struct { | ||||
| 	Object string `json:"object"` | ||||
| 	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"` | ||||
| @@ -61,6 +69,14 @@ type API2GPTUsageResponse struct { | ||||
| 	TotalRemaining float64 `json:"total_remaining"` | ||||
| } | ||||
|  | ||||
| type APGC2DGPTUsageResponse struct { | ||||
| 	//Grants         interface{} `json:"grants"` | ||||
| 	Object         string  `json:"object"` | ||||
| 	TotalAvailable float64 `json:"total_available"` | ||||
| 	TotalGranted   float64 `json:"total_granted"` | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| } | ||||
|  | ||||
| // GetAuthHeader get auth header | ||||
| func GetAuthHeader(token string) http.Header { | ||||
| 	h := http.Header{} | ||||
| @@ -69,7 +85,6 @@ func GetAuthHeader(token string) http.Header { | ||||
| } | ||||
|  | ||||
| func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { | ||||
| 	client := &http.Client{} | ||||
| 	req, err := http.NewRequest(method, url, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -77,10 +92,13 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | ||||
| 	for k := range headers { | ||||
| 		req.Header.Add(k, headers.Get(k)) | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	res, err := httpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if res.StatusCode != http.StatusOK { | ||||
| 		return nil, fmt.Errorf("status code: %d", res.StatusCode) | ||||
| 	} | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -92,6 +110,22 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | ||||
| 	return body, nil | ||||
| } | ||||
|  | ||||
| func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := OpenAICreditGrants{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(response.TotalAvailable) | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
|  | ||||
| func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| @@ -150,8 +184,26 @@ func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) { | ||||
| 	return response.TotalRemaining, nil | ||||
| } | ||||
|  | ||||
| func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.aigc2d.com/dashboard/billing/credit_grants" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := APGC2DGPTUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(response.TotalAvailable) | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.BaseURL == "" { | ||||
| 		channel.BaseURL = baseURL | ||||
| 	} | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypeOpenAI: | ||||
| 		if channel.BaseURL != "" { | ||||
| @@ -161,12 +213,16 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	case common.ChannelTypeCustom: | ||||
| 		baseURL = channel.BaseURL | ||||
| 	case common.ChannelTypeCloseAI: | ||||
| 		return updateChannelCloseAIBalance(channel) | ||||
| 	case common.ChannelTypeOpenAISB: | ||||
| 		return updateChannelOpenAISBBalance(channel) | ||||
| 	case common.ChannelTypeAIProxy: | ||||
| 		return updateChannelAIProxyBalance(channel) | ||||
| 	case common.ChannelTypeAPI2GPT: | ||||
| 		return updateChannelAPI2GPTBalance(channel) | ||||
| 	case common.ChannelTypeAIGC2D: | ||||
| 		return updateChannelAIGC2DBalance(channel) | ||||
| 	default: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	} | ||||
|   | ||||
| @@ -14,8 +14,18 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func testChannel(channel *model.Channel, request ChatRequest) error { | ||||
| func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		request.Model = "gpt-35-turbo" | ||||
| 	default: | ||||
| @@ -33,11 +43,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error { | ||||
|  | ||||
| 	jsonData, err := json.Marshal(request) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		req.Header.Set("api-key", channel.Key) | ||||
| @@ -45,21 +55,20 @@ func testChannel(channel *model.Channel, request ChatRequest) error { | ||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	var response TextResponse | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&response) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	if response.Usage.CompletionTokens == 0 { | ||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) | ||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||
| 	} | ||||
| 	return nil | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func buildTestRequest() *ChatRequest { | ||||
| @@ -94,7 +103,7 @@ func TestChannel(c *gin.Context) { | ||||
| 	} | ||||
| 	testRequest := buildTestRequest() | ||||
| 	tik := time.Now() | ||||
| 	err = testChannel(channel, *testRequest) | ||||
| 	err, _ = testChannel(channel, *testRequest) | ||||
| 	tok := time.Now() | ||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 	go channel.UpdateResponseTime(milliseconds) | ||||
| @@ -158,13 +167,14 @@ func testAllChannels(notify bool) error { | ||||
| 				continue | ||||
| 			} | ||||
| 			tik := time.Now() | ||||
| 			err := testChannel(channel, *testRequest) | ||||
| 			err, openaiErr := testChannel(channel, *testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if err != nil || milliseconds > disableThreshold { | ||||
| 				if milliseconds > disableThreshold { | ||||
| 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				} | ||||
| 			if milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if shouldDisableChannel(openaiErr) { | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			channel.UpdateResponseTime(milliseconds) | ||||
|   | ||||
| @@ -13,7 +13,12 @@ func GetAllLogs(c *gin.Context) { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | ||||
| 	logs, err := model.GetAllLogs(logType, p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||
| 	username := c.Query("username") | ||||
| 	tokenName := c.Query("token_name") | ||||
| 	modelName := c.Query("model_name") | ||||
| 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	if err != nil { | ||||
| 		c.JSON(200, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -35,7 +40,11 @@ func GetUserLogs(c *gin.Context) { | ||||
| 	} | ||||
| 	userId := c.GetInt("id") | ||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | ||||
| 	logs, err := model.GetUserLogs(userId, logType, p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||
| 	tokenName := c.Query("token_name") | ||||
| 	modelName := c.Query("model_name") | ||||
| 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | ||||
| 	if err != nil { | ||||
| 		c.JSON(200, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -84,3 +93,41 @@ func SearchUserLogs(c *gin.Context) { | ||||
| 		"data":    logs, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func GetLogsStat(c *gin.Context) { | ||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | ||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||
| 	tokenName := c.Query("token_name") | ||||
| 	username := c.Query("username") | ||||
| 	modelName := c.Query("model_name") | ||||
| 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | ||||
| 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") | ||||
| 	c.JSON(200, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data": gin.H{ | ||||
| 			"quota": quotaNum, | ||||
| 			//"token": tokenNum, | ||||
| 		}, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func GetLogsSelfStat(c *gin.Context) { | ||||
| 	username := c.GetString("username") | ||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | ||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||
| 	tokenName := c.Query("token_name") | ||||
| 	modelName := c.Query("model_name") | ||||
| 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | ||||
| 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | ||||
| 	c.JSON(200, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data": gin.H{ | ||||
| 			"quota": quotaNum, | ||||
| 			//"token": tokenNum, | ||||
| 		}, | ||||
| 	}) | ||||
| } | ||||
|   | ||||
| @@ -3,10 +3,12 @@ package controller | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func GetStatus(c *gin.Context) { | ||||
| @@ -78,6 +80,22 @@ func SendEmailVerification(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if common.EmailDomainRestrictionEnabled { | ||||
| 		allowed := false | ||||
| 		for _, domain := range common.EmailDomainWhitelist { | ||||
| 			if strings.HasSuffix(email, "@"+domain) { | ||||
| 				allowed = true | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		if !allowed { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	if model.IsEmailAlreadyTaken(email) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -127,8 +145,9 @@ func SendPasswordResetEmail(c *gin.Context) { | ||||
| 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) | ||||
| 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | ||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | ||||
| 		"<p>点击<a href='%s'>此处</a>进行密码重置。</p>"+ | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, common.VerificationValidMinutes) | ||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| @@ -53,6 +54,15 @@ func init() { | ||||
| 	}) | ||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| 	openAIModels = []OpenAIModels{ | ||||
| 		{ | ||||
| 			Id:         "dall-e", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "dall-e", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo", | ||||
| 			Object:     "model", | ||||
| @@ -224,6 +234,132 @@ func init() { | ||||
| 			Root:       "text-moderation-stable", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "text-davinci-edit-001", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "text-davinci-edit-001", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "code-davinci-edit-001", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "code-davinci-edit-001", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-instant-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-instant-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "ERNIE-Bot", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot-turbo", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "ERNIE-Bot-turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "Embedding-V1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "Embedding-V1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "PaLM-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "google", | ||||
| 			Permission: permission, | ||||
| 			Root:       "PaLM-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_pro", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_pro", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_std", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_std", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_lite", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_lite", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-v1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-v1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-plus-v1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-plus-v1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "SparkDesk", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "xunfei", | ||||
| 			Permission: permission, | ||||
| 			Root:       "SparkDesk", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 	} | ||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | ||||
| 	for _, model := range openAIModels { | ||||
|   | ||||
| @@ -2,11 +2,12 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func GetOptions(c *gin.Context) { | ||||
| @@ -49,6 +50,14 @@ func UpdateOption(c *gin.Context) { | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	case "EmailDomainRestrictionEnabled": | ||||
| 		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	case "WeChatAuthEnabled": | ||||
| 		if option.Value == "true" && common.WeChatServerAddress == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
							
								
								
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,240 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||
|  | ||||
| type AliMessage struct { | ||||
| 	User string `json:"user"` | ||||
| 	Bot  string `json:"bot"` | ||||
| } | ||||
|  | ||||
| type AliInput struct { | ||||
| 	Prompt  string       `json:"prompt"` | ||||
| 	History []AliMessage `json:"history"` | ||||
| } | ||||
|  | ||||
| type AliParameters struct { | ||||
| 	TopP         float64 `json:"top_p,omitempty"` | ||||
| 	TopK         int     `json:"top_k,omitempty"` | ||||
| 	Seed         uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliChatRequest struct { | ||||
| 	Model      string        `json:"model"` | ||||
| 	Input      AliInput      `json:"input"` | ||||
| 	Parameters AliParameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliError struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type AliUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
|  | ||||
| type AliOutput struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type AliChatResponse struct { | ||||
| 	Output AliOutput `json:"output"` | ||||
| 	Usage  AliUsage  `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||
| 	prompt := "" | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			if i == len(request.Messages)-1 { | ||||
| 				prompt = message.Content | ||||
| 				break | ||||
| 			} | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  request.Messages[i+1].Content, | ||||
| 			}) | ||||
| 			i++ | ||||
| 		} | ||||
| 	} | ||||
| 	return &AliChatRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Input: AliInput{ | ||||
| 			Prompt:  prompt, | ||||
| 			History: messages, | ||||
| 		}, | ||||
| 		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's | ||||
| 		//	TopP: request.TopP, | ||||
| 		//	TopK: 50, | ||||
| 		//	//Seed:         0, | ||||
| 		//	//EnableSearch: false, | ||||
| 		//}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage: Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	choice.FinishReason = aliResponse.Output.FinishReason | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	lastResponseText := "" | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse AliChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage.PromptTokens += aliResponse.Usage.InputTokens | ||||
| 			usage.CompletionTokens += aliResponse.Usage.OutputTokens | ||||
| 			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var aliResponse AliChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
							
								
								
									
										299
									
								
								controller/relay-baidu.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										299
									
								
								controller/relay-baidu.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,299 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 | ||||
|  | ||||
| type BaiduTokenResponse struct { | ||||
| 	RefreshToken  string `json:"refresh_token"` | ||||
| 	ExpiresIn     int    `json:"expires_in"` | ||||
| 	SessionKey    string `json:"session_key"` | ||||
| 	AccessToken   string `json:"access_token"` | ||||
| 	Scope         string `json:"scope"` | ||||
| 	SessionSecret string `json:"session_secret"` | ||||
| } | ||||
|  | ||||
| type BaiduMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type BaiduChatRequest struct { | ||||
| 	Messages []BaiduMessage `json:"messages"` | ||||
| 	Stream   bool           `json:"stream"` | ||||
| 	UserId   string         `json:"user_id,omitempty"` | ||||
| } | ||||
|  | ||||
| type BaiduError struct { | ||||
| 	ErrorCode int    `json:"error_code"` | ||||
| 	ErrorMsg  string `json:"error_msg"` | ||||
| } | ||||
|  | ||||
| type BaiduChatResponse struct { | ||||
| 	Id               string `json:"id"` | ||||
| 	Object           string `json:"object"` | ||||
| 	Created          int64  `json:"created"` | ||||
| 	Result           string `json:"result"` | ||||
| 	IsTruncated      bool   `json:"is_truncated"` | ||||
| 	NeedClearHistory bool   `json:"need_clear_history"` | ||||
| 	Usage            Usage  `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| type BaiduChatStreamResponse struct { | ||||
| 	BaiduChatResponse | ||||
| 	SentenceId int  `json:"sentence_id"` | ||||
| 	IsEnd      bool `json:"is_end"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingRequest struct { | ||||
| 	Input []string `json:"input"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingData struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	Index     int       `json:"index"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingResponse struct { | ||||
| 	Id      string               `json:"id"` | ||||
| 	Object  string               `json:"object"` | ||||
| 	Created int64                `json:"created"` | ||||
| 	Data    []BaiduEmbeddingData `json:"data"` | ||||
| 	Usage   Usage                `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &BaiduChatRequest{ | ||||
| 		Messages: messages, | ||||
| 		Stream:   request.Stream, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Result, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.Id, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: response.Created, | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage:   response.Usage, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = baiduResponse.Result | ||||
| 	choice.FinishReason = "stop" | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      baiduResponse.Id, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: baiduResponse.Created, | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||
| 	baiduEmbeddingRequest := BaiduEmbeddingRequest{ | ||||
| 		Input: nil, | ||||
| 	} | ||||
| 	switch request.Input.(type) { | ||||
| 	case string: | ||||
| 		baiduEmbeddingRequest.Input = []string{request.Input.(string)} | ||||
| 	case []string: | ||||
| 		baiduEmbeddingRequest.Input = request.Input.([]string) | ||||
| 	} | ||||
| 	return &baiduEmbeddingRequest | ||||
| } | ||||
|  | ||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Model:  "baidu-embedding", | ||||
| 		Usage:  response.Usage, | ||||
| 	} | ||||
| 	for _, item := range response.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[6:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var baiduResponse BaiduChatStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage.PromptTokens += baiduResponse.Usage.PromptTokens | ||||
| 			usage.CompletionTokens += baiduResponse.Usage.CompletionTokens | ||||
| 			usage.TotalTokens += baiduResponse.Usage.TotalTokens | ||||
| 			response := streamResponseBaidu2OpenAI(&baiduResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduEmbeddingResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
							
								
								
									
										221
									
								
								controller/relay-claude.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										221
									
								
								controller/relay-claude.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,221 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type ClaudeMetadata struct { | ||||
| 	UserId string `json:"user_id"` | ||||
| } | ||||
|  | ||||
| type ClaudeRequest struct { | ||||
| 	Model             string   `json:"model"` | ||||
| 	Prompt            string   `json:"prompt"` | ||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||
| 	Temperature       float64  `json:"temperature,omitempty"` | ||||
| 	TopP              float64  `json:"top_p,omitempty"` | ||||
| 	TopK              int      `json:"top_k,omitempty"` | ||||
| 	//ClaudeMetadata    `json:"metadata,omitempty"` | ||||
| 	Stream bool `json:"stream,omitempty"` | ||||
| } | ||||
|  | ||||
| type ClaudeError struct { | ||||
| 	Type    string `json:"type"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type ClaudeResponse struct { | ||||
| 	Completion string      `json:"completion"` | ||||
| 	StopReason string      `json:"stop_reason"` | ||||
| 	Model      string      `json:"model"` | ||||
| 	Error      ClaudeError `json:"error"` | ||||
| } | ||||
|  | ||||
| func stopReasonClaude2OpenAI(reason string) string { | ||||
| 	switch reason { | ||||
| 	case "stop_sequence": | ||||
| 		return "stop" | ||||
| 	case "max_tokens": | ||||
| 		return "length" | ||||
| 	default: | ||||
| 		return reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 	claudeRequest := ClaudeRequest{ | ||||
| 		Model:             textRequest.Model, | ||||
| 		Prompt:            "", | ||||
| 		MaxTokensToSample: textRequest.MaxTokens, | ||||
| 		StopSequences:     nil, | ||||
| 		Temperature:       textRequest.Temperature, | ||||
| 		TopP:              textRequest.TopP, | ||||
| 		Stream:            textRequest.Stream, | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokensToSample == 0 { | ||||
| 		claudeRequest.MaxTokensToSample = 1000000 | ||||
| 	} | ||||
| 	prompt := "" | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		if message.Role == "user" { | ||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | ||||
| 		} else if message.Role == "assistant" { | ||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||
| 		} else if message.Role == "system" { | ||||
| 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||
| 		} | ||||
| 	} | ||||
| 	prompt += "\n\nAssistant:" | ||||
| 	claudeRequest.Prompt = prompt | ||||
| 	return &claudeRequest | ||||
| } | ||||
|  | ||||
| func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = claudeResponse.Completion | ||||
| 	choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason) | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = claudeResponse.Model | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { | ||||
| 			return i + 4, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if !strings.HasPrefix(data, "event: completion") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "event: completion\r\ndata: ") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var claudeResponse ClaudeResponse | ||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			responseText += claudeResponse.Completion | ||||
| 			response := streamResponseClaude2OpenAI(&claudeResponse) | ||||
| 			response.Id = responseId | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var claudeResponse ClaudeResponse | ||||
| 	err = json.Unmarshal(responseBody, &claudeResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if claudeResponse.Error.Type != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: claudeResponse.Error.Message, | ||||
| 				Type:    claudeResponse.Error.Type, | ||||
| 				Param:   "", | ||||
| 				Code:    claudeResponse.Error.Type, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||
| 	completionTokens := countTokenText(claudeResponse.Completion, model) | ||||
| 	usage := Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
| @@ -1,34 +1,180 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	// TODO: this part is not finished | ||||
| 	req, err := http.NewRequest(c.Request.Method, c.Request.RequestURI, c.Request.Body) | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusOK) | ||||
| 	imageModel := "dall-e" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	userId := c.GetInt("id") | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	var imageRequest ImageRequest | ||||
| 	if consumeQuota { | ||||
| 		err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Not "256x256", "512x512", or "1024x1024" | ||||
| 	if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { | ||||
| 		return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// N should between 1 and 10 | ||||
| 	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { | ||||
| 		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[imageModel] != "" { | ||||
| 			imageModel = modelMap[imageModel] | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
|  | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
|  | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
|  | ||||
| 	modelRatio := common.GetModelRatio(imageModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
|  | ||||
| 	sizeRatio := 1.0 | ||||
| 	// Size | ||||
| 	if imageRequest.Size == "256x256" { | ||||
| 		sizeRatio = 1 | ||||
| 	} else if imageRequest.Size == "512x512" { | ||||
| 		sizeRatio = 1.125 | ||||
| 	} else if imageRequest.Size == "1024x1024" { | ||||
| 		sizeRatio = 1.25 | ||||
| 	} | ||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||
|  | ||||
| 	if consumeQuota && userQuota-quota < 0 { | ||||
| 		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
|  | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
|  | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
|  | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusOK) | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var textResponse ImageResponse | ||||
|  | ||||
| 	defer func() { | ||||
| 		if consumeQuota { | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
|  | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
|  | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusOK) | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusOK) | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
							
								
								
									
										148
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,148 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:6] != "data: " && data[:6] != "[DONE]" { | ||||
| 				continue | ||||
| 			} | ||||
| 			dataChan <- data | ||||
| 			data = data[6:] | ||||
| 			if !strings.HasPrefix(data, "[DONE]") { | ||||
| 				switch relayMode { | ||||
| 				case RelayModeChatCompletions: | ||||
| 					var streamResponse ChatCompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue // just ignore the error | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Delta.Content | ||||
| 					} | ||||
| 				case RelayModeCompletions: | ||||
| 					var streamResponse CompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Text | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			if strings.HasPrefix(data, "data: [DONE]") { | ||||
| 				data = data[:12] | ||||
| 			} | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			c.Render(-1, common.CustomEvent{Data: data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse TextResponse | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		if textResponse.Error.Type != "" { | ||||
| 			return &OpenAIErrorWithStatusCode{ | ||||
| 				OpenAIError: textResponse.Error, | ||||
| 				StatusCode:  resp.StatusCode, | ||||
| 			}, nil | ||||
| 		} | ||||
| 		// Reset response body | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 	// So the httpClient will be confused by the response. | ||||
| 	// For example, Postman will report error, and we cannot check the response at all. | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err := io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if textResponse.Usage.TotalTokens == 0 { | ||||
| 		completionTokens := 0 | ||||
| 		for _, choice := range textResponse.Choices { | ||||
| 			completionTokens += countTokenText(choice.Message.Content, model) | ||||
| 		} | ||||
| 		textResponse.Usage = Usage{ | ||||
| 			PromptTokens:     promptTokens, | ||||
| 			CompletionTokens: completionTokens, | ||||
| 			TotalTokens:      promptTokens + completionTokens, | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, &textResponse.Usage | ||||
| } | ||||
| @@ -1,10 +1,17 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| ) | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||
|  | ||||
| type PaLMChatMessage struct { | ||||
| 	Author  string `json:"author"` | ||||
| 	Content string `json:"content"` | ||||
| @@ -15,45 +22,188 @@ type PaLMFilter struct { | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||
| type PaLMPrompt struct { | ||||
| 	Messages []PaLMChatMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| type PaLMChatRequest struct { | ||||
| 	Prompt         []Message `json:"prompt"` | ||||
| 	Temperature    float64   `json:"temperature"` | ||||
| 	CandidateCount int       `json:"candidateCount"` | ||||
| 	TopP           float64   `json:"topP"` | ||||
| 	TopK           int       `json:"topK"` | ||||
| 	Prompt         PaLMPrompt `json:"prompt"` | ||||
| 	Temperature    float64    `json:"temperature,omitempty"` | ||||
| 	CandidateCount int        `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64    `json:"topP,omitempty"` | ||||
| 	TopK           int        `json:"topK,omitempty"` | ||||
| } | ||||
|  | ||||
| type PaLMError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  string `json:"status"` | ||||
| } | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||
| type PaLMChatResponse struct { | ||||
| 	Candidates []Message    `json:"candidates"` | ||||
| 	Messages   []Message    `json:"messages"` | ||||
| 	Filters    []PaLMFilter `json:"filters"` | ||||
| 	Candidates []PaLMChatMessage `json:"candidates"` | ||||
| 	Messages   []Message         `json:"messages"` | ||||
| 	Filters    []PaLMFilter      `json:"filters"` | ||||
| 	Error      PaLMError         `json:"error"` | ||||
| } | ||||
|  | ||||
| func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode { | ||||
| 	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage | ||||
| 	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages)) | ||||
| 	for _, message := range openAIRequest.Messages { | ||||
| 		var author string | ||||
| 		if message.Role == "user" { | ||||
| 			author = "0" | ||||
| 		} else { | ||||
| 			author = "1" | ||||
| 		} | ||||
| 		messages = append(messages, PaLMChatMessage{ | ||||
| 			Author:  author, | ||||
| func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| 	palmRequest := PaLMChatRequest{ | ||||
| 		Prompt: PaLMPrompt{ | ||||
| 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), | ||||
| 		}, | ||||
| 		Temperature:    textRequest.Temperature, | ||||
| 		CandidateCount: textRequest.N, | ||||
| 		TopP:           textRequest.TopP, | ||||
| 		TopK:           textRequest.MaxTokens, | ||||
| 	} | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		palmMessage := PaLMChatMessage{ | ||||
| 			Content: message.Content, | ||||
| 		}) | ||||
| 		} | ||||
| 		if message.Role == "user" { | ||||
| 			palmMessage.Author = "0" | ||||
| 		} else { | ||||
| 			palmMessage.Author = "1" | ||||
| 		} | ||||
| 		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) | ||||
| 	} | ||||
| 	request := PaLMChatRequest{ | ||||
| 		Prompt:         nil, | ||||
| 		Temperature:    openAIRequest.Temperature, | ||||
| 		CandidateCount: openAIRequest.N, | ||||
| 		TopP:           openAIRequest.TopP, | ||||
| 		TopK:           openAIRequest.MaxTokens, | ||||
| 	} | ||||
| 	// TODO: forward request to PaLM & convert response | ||||
| 	fmt.Print(request) | ||||
| 	return nil | ||||
| 	return &palmRequest | ||||
| } | ||||
|  | ||||
| func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | ||||
| 	} | ||||
| 	for i, candidate := range response.Candidates { | ||||
| 		choice := OpenAITextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: candidate.Content, | ||||
| 			}, | ||||
| 			FinishReason: "stop", | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	if len(palmResponse.Candidates) > 0 { | ||||
| 		choice.Delta.Content = palmResponse.Candidates[0].Content | ||||
| 	} | ||||
| 	choice.FinishReason = "stop" | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = "palm2" | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error reading stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			common.SysError("error closing stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		var palmResponse PaLMChatResponse | ||||
| 		err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) | ||||
| 		fullTextResponse.Id = responseId | ||||
| 		fullTextResponse.Created = createdTime | ||||
| 		if len(palmResponse.Candidates) > 0 { | ||||
| 			responseText = palmResponse.Candidates[0].Content | ||||
| 		} | ||||
| 		jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		dataChan <- string(jsonResponse) | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var palmResponse PaLMChatResponse | ||||
| 	err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: palmResponse.Error.Message, | ||||
| 				Type:    palmResponse.Error.Status, | ||||
| 				Param:   "", | ||||
| 				Code:    palmResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) | ||||
| 	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) | ||||
| 	usage := Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
|   | ||||
| @@ -1,18 +1,35 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	APITypeOpenAI = iota | ||||
| 	APITypeClaude | ||||
| 	APITypePaLM | ||||
| 	APITypeBaidu | ||||
| 	APITypeZhipu | ||||
| 	APITypeAli | ||||
| 	APITypeXunfei | ||||
| ) | ||||
|  | ||||
| var httpClient *http.Client | ||||
|  | ||||
| func init() { | ||||
| 	httpClient = &http.Client{} | ||||
| } | ||||
|  | ||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| @@ -26,44 +43,135 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	if relayMode == RelayModeModeration && textRequest.Model == "" { | ||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { | ||||
| 		textRequest.Model = "text-moderation-latest" | ||||
| 	} | ||||
| 	if relayMode == RelayModeEmbeddings && textRequest.Model == "" { | ||||
| 		textRequest.Model = c.Param("model") | ||||
| 	} | ||||
| 	// request validation | ||||
| 	if textRequest.Model == "" { | ||||
| 		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 	switch relayMode { | ||||
| 	case RelayModeCompletions: | ||||
| 		if textRequest.Prompt == "" { | ||||
| 			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeChatCompletions: | ||||
| 		if textRequest.Messages == nil || len(textRequest.Messages) == 0 { | ||||
| 			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeEmbeddings: | ||||
| 	case RelayModeModerations: | ||||
| 		if textRequest.Input == "" { | ||||
| 			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeEdits: | ||||
| 		if textRequest.Instruction == "" { | ||||
| 			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" && modelMapping != "{}" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[textRequest.Model] != "" { | ||||
| 			textRequest.Model = modelMap[textRequest.Model] | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
| 	apiType := APITypeOpenAI | ||||
| 	switch channelType { | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		apiType = APITypeClaude | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		apiType = APITypeBaidu | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		apiType = APITypePaLM | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		apiType = APITypeZhipu | ||||
| 	case common.ChannelTypeAli: | ||||
| 		apiType = APITypeAli | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		apiType = APITypeXunfei | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	if channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 		query := c.Request.URL.Query() | ||||
| 		apiVersion := query.Get("api-version") | ||||
| 		if apiVersion == "" { | ||||
| 			apiVersion = c.GetString("api_version") | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if channelType == common.ChannelTypeAzure { | ||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 			query := c.Request.URL.Query() | ||||
| 			apiVersion := query.Get("api-version") | ||||
| 			if apiVersion == "" { | ||||
| 				apiVersion = c.GetString("api_version") | ||||
| 			} | ||||
| 			requestURL := strings.Split(requestURL, "?")[0] | ||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||
| 			baseURL = c.GetString("base_url") | ||||
| 			task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 			model_ := textRequest.Model | ||||
| 			model_ = strings.Replace(model_, ".", "", -1) | ||||
| 			// https://github.com/songquanpeng/one-api/issues/67 | ||||
| 			model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | ||||
| 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||
| 		} | ||||
| 		requestURL := strings.Split(requestURL, "?")[0] | ||||
| 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 		task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 		model_ := textRequest.Model | ||||
| 		model_ = strings.Replace(model_, ".", "", -1) | ||||
| 		// https://github.com/songquanpeng/one-api/issues/67 | ||||
| 		model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 		model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 		model_ = strings.TrimSuffix(model_, "-0613") | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||
| 	} else if channelType == common.ChannelTypePaLM { | ||||
| 		err := relayPaLM(textRequest, c) | ||||
| 		return err | ||||
| 	case APITypeClaude: | ||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) | ||||
| 		} | ||||
| 	case APITypeBaidu: | ||||
| 		switch textRequest.Model { | ||||
| 		case "ERNIE-Bot": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 		case "ERNIE-Bot-turbo": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days | ||||
| 	case APITypePaLM: | ||||
| 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		fullRequestURL += "?key=" + apiKey | ||||
| 	case APITypeZhipu: | ||||
| 		method := "invoke" | ||||
| 		if textRequest.Stream { | ||||
| 			method = "sse-invoke" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||
| 	case APITypeAli: | ||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||
| 	} | ||||
| 	var promptTokens int | ||||
| 	var completionTokens int | ||||
| 	switch relayMode { | ||||
| 	case RelayModeChatCompletions: | ||||
| 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) | ||||
| 	case RelayModeCompletions: | ||||
| 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) | ||||
| 	case RelayModeModeration: | ||||
| 	case RelayModeModerations: | ||||
| 		promptTokens = countTokenInput(textRequest.Input, textRequest.Model) | ||||
| 	} | ||||
| 	preConsumedTokens := common.PreConsumedQuota | ||||
| @@ -89,183 +197,309 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 		} | ||||
| 	} | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if channelType == common.ChannelTypeAzure { | ||||
| 		key := c.Request.Header.Get("Authorization") | ||||
| 		key = strings.TrimPrefix(key, "Bearer ") | ||||
| 		req.Header.Set("api-key", key) | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 	req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	switch apiType { | ||||
| 	case APITypeClaude: | ||||
| 		claudeRequest := requestOpenAI2Claude(textRequest) | ||||
| 		jsonStr, err := json.Marshal(claudeRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeBaidu: | ||||
| 		var jsonData []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case RelayModeEmbeddings: | ||||
| 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||
| 		default: | ||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonData) | ||||
| 	case APITypePaLM: | ||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) | ||||
| 		jsonStr, err := json.Marshal(palmRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeZhipu: | ||||
| 		zhipuRequest := requestOpenAI2Zhipu(textRequest) | ||||
| 		jsonStr, err := json.Marshal(zhipuRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeAli: | ||||
| 		aliRequest := requestOpenAI2Ali(textRequest) | ||||
| 		jsonStr, err := json.Marshal(aliRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
|  | ||||
| 	var req *http.Request | ||||
| 	var resp *http.Response | ||||
| 	isStream := textRequest.Stream | ||||
|  | ||||
| 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		switch apiType { | ||||
| 		case APITypeOpenAI: | ||||
| 			if channelType == common.ChannelTypeAzure { | ||||
| 				req.Header.Set("api-key", apiKey) | ||||
| 			} else { | ||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 			} | ||||
| 		case APITypeClaude: | ||||
| 			req.Header.Set("x-api-key", apiKey) | ||||
| 			anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 			if anthropicVersion == "" { | ||||
| 				anthropicVersion = "2023-06-01" | ||||
| 			} | ||||
| 			req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 		case APITypeZhipu: | ||||
| 			token := getZhipuToken(apiKey) | ||||
| 			req.Header.Set("Authorization", token) | ||||
| 		case APITypeAli: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 			if textRequest.Stream { | ||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 			} | ||||
| 		} | ||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = req.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = c.Request.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	} | ||||
|  | ||||
| 	var textResponse TextResponse | ||||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	var streamResponseText string | ||||
|  | ||||
| 	defer func() { | ||||
| 		if consumeQuota { | ||||
| 			quota := 0 | ||||
| 			completionRatio := 1.34 // default for gpt-3 | ||||
| 			completionRatio := 1.0 | ||||
| 			if strings.HasPrefix(textRequest.Model, "gpt-3.5") { | ||||
| 				completionRatio = 1.333333 | ||||
| 			} | ||||
| 			if strings.HasPrefix(textRequest.Model, "gpt-4") { | ||||
| 				completionRatio = 2 | ||||
| 			} | ||||
| 			if isStream { | ||||
| 				responseTokens := countTokenText(streamResponseText, textRequest.Model) | ||||
| 				quota = promptTokens + int(float64(responseTokens)*completionRatio) | ||||
| 			} else { | ||||
| 				quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio) | ||||
| 			} | ||||
|  | ||||
| 			promptTokens = textResponse.Usage.PromptTokens | ||||
| 			completionTokens = textResponse.Usage.CompletionTokens | ||||
|  | ||||
| 			quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||
| 			quota = int(float64(quota) * ratio) | ||||
| 			if ratio != 0 && quota <= 0 { | ||||
| 				quota = 1 | ||||
| 			} | ||||
| 			totalTokens := promptTokens + completionTokens | ||||
| 			if totalTokens == 0 { | ||||
| 				// in this case, must be some error happened | ||||
| 				// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 				quota = 0 | ||||
| 			} | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			tokenName := c.GetString("token_name") | ||||
| 			model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, common.LogQuota(quota), modelRatio, groupRatio)) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	if isStream { | ||||
| 		scanner := bufio.NewScanner(resp.Body) | ||||
| 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 			if atEOF && len(data) == 0 { | ||||
| 				return 0, nil, nil | ||||
| 			} | ||||
|  | ||||
| 			if i := strings.Index(string(data), "\n\n"); i >= 0 { | ||||
| 				return i + 2, data[0:i], nil | ||||
| 			} | ||||
|  | ||||
| 			if atEOF { | ||||
| 				return len(data), data, nil | ||||
| 			} | ||||
|  | ||||
| 			return 0, nil, nil | ||||
| 		}) | ||||
| 		dataChan := make(chan string) | ||||
| 		stopChan := make(chan bool) | ||||
| 		go func() { | ||||
| 			for scanner.Scan() { | ||||
| 				data := scanner.Text() | ||||
| 				if len(data) < 6 { // must be something wrong! | ||||
| 					common.SysError("invalid stream response: " + data) | ||||
| 					continue | ||||
| 				} | ||||
| 				dataChan <- data | ||||
| 				data = data[6:] | ||||
| 				if !strings.HasPrefix(data, "[DONE]") { | ||||
| 					switch relayMode { | ||||
| 					case RelayModeChatCompletions: | ||||
| 						var streamResponse ChatCompletionsStreamResponse | ||||
| 						err = json.Unmarshal([]byte(data), &streamResponse) | ||||
| 						if err != nil { | ||||
| 							common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 							return | ||||
| 						} | ||||
| 						for _, choice := range streamResponse.Choices { | ||||
| 							streamResponseText += choice.Delta.Content | ||||
| 						} | ||||
| 					case RelayModeCompletions: | ||||
| 						var streamResponse CompletionsStreamResponse | ||||
| 						err = json.Unmarshal([]byte(data), &streamResponse) | ||||
| 						if err != nil { | ||||
| 							common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 							return | ||||
| 						} | ||||
| 						for _, choice := range streamResponse.Choices { | ||||
| 							streamResponseText += choice.Text | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			stopChan <- true | ||||
| 		}() | ||||
| 		c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 		c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 		c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 		c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 		c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 		c.Stream(func(w io.Writer) bool { | ||||
| 			select { | ||||
| 			case data := <-dataChan: | ||||
| 				if strings.HasPrefix(data, "data: [DONE]") { | ||||
| 					data = data[:12] | ||||
| 				} | ||||
| 				c.Render(-1, common.CustomEvent{Data: data}) | ||||
| 				return true | ||||
| 			case <-stopChan: | ||||
| 				return false | ||||
| 			} | ||||
| 		}) | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		return nil | ||||
| 	} else { | ||||
| 		if consumeQuota { | ||||
| 			responseBody, err := io.ReadAll(resp.Body) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if isStream { | ||||
| 			err, responseText := openaiStreamHandler(c, resp, relayMode) | ||||
| 			if err != nil { | ||||
| 				return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 				return err | ||||
| 			} | ||||
| 			err = resp.Body.Close() | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 				return err | ||||
| 			} | ||||
| 			err = json.Unmarshal(responseBody, &textResponse) | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		if isStream { | ||||
| 			err, responseText := claudeStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 				return err | ||||
| 			} | ||||
| 			if textResponse.Error.Type != "" { | ||||
| 				return &OpenAIErrorWithStatusCode{ | ||||
| 					OpenAIError: textResponse.Error, | ||||
| 					StatusCode:  resp.StatusCode, | ||||
| 				} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			// Reset response body | ||||
| 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 		// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 		// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 		// So the client will be confused by the response. | ||||
| 		// For example, Postman will report error, and we cannot check the response at all. | ||||
| 		for k, v := range resp.Header { | ||||
| 			c.Writer.Header().Set(k, v[0]) | ||||
| 	case APITypeBaidu: | ||||
| 		if isStream { | ||||
| 			err, usage := baiduStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			var err *OpenAIErrorWithStatusCode | ||||
| 			var usage *Usage | ||||
| 			switch relayMode { | ||||
| 			case RelayModeEmbeddings: | ||||
| 				err, usage = baiduEmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = baiduHandler(c, resp) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 		c.Writer.WriteHeader(resp.StatusCode) | ||||
| 		_, err = io.Copy(c.Writer, resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 	case APITypePaLM: | ||||
| 		if textRequest.Stream { // PaLM2 API does not support stream | ||||
| 			err, responseText := palmStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	case APITypeZhipu: | ||||
| 		if isStream { | ||||
| 			err, usage := zhipuStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := zhipuHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} | ||||
| 		return nil | ||||
| 	case APITypeAli: | ||||
| 		if isStream { | ||||
| 			err, usage := aliStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := aliHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeXunfei: | ||||
| 		if isStream { | ||||
| 			auth := c.Request.Header.Get("Authorization") | ||||
| 			auth = strings.TrimPrefix(auth, "Bearer ") | ||||
| 			splits := strings.Split(auth, "|") | ||||
| 			if len(splits) != 3 { | ||||
| 				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||
| 			} | ||||
| 			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) | ||||
| 		} | ||||
| 	default: | ||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -4,7 +4,6 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||
| @@ -25,6 +24,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||
| 	return tokenEncoder | ||||
| } | ||||
|  | ||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||
| 	if common.ApproximateTokenEnabled { | ||||
| 		return int(float64(len(text)) * 0.38) | ||||
| 	} | ||||
| 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||
| } | ||||
|  | ||||
| func countTokenMessages(messages []Message, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	// Reference: | ||||
| @@ -34,12 +40,9 @@ func countTokenMessages(messages []Message, model string) int { | ||||
| 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n | ||||
| 	var tokensPerMessage int | ||||
| 	var tokensPerName int | ||||
| 	if strings.HasPrefix(model, "gpt-3.5") { | ||||
| 	if model == "gpt-3.5-turbo-0301" { | ||||
| 		tokensPerMessage = 4 | ||||
| 		tokensPerName = -1 // If there's a name, the role is omitted | ||||
| 	} else if strings.HasPrefix(model, "gpt-4") { | ||||
| 		tokensPerMessage = 3 | ||||
| 		tokensPerName = 1 | ||||
| 	} else { | ||||
| 		tokensPerMessage = 3 | ||||
| 		tokensPerName = 1 | ||||
| @@ -47,11 +50,11 @@ func countTokenMessages(messages []Message, model string) int { | ||||
| 	tokenNum := 0 | ||||
| 	for _, message := range messages { | ||||
| 		tokenNum += tokensPerMessage | ||||
| 		tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil)) | ||||
| 		tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil)) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Content) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Role) | ||||
| 		if message.Name != nil { | ||||
| 			tokenNum += tokensPerName | ||||
| 			tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil)) | ||||
| 			tokenNum += getTokenNum(tokenEncoder, *message.Name) | ||||
| 		} | ||||
| 	} | ||||
| 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> | ||||
| @@ -74,8 +77,7 @@ func countTokenInput(input any, model string) int { | ||||
|  | ||||
| func countTokenText(text string, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	token := tokenEncoder.Encode(text, nil, nil) | ||||
| 	return len(token) | ||||
| 	return getTokenNum(tokenEncoder, text) | ||||
| } | ||||
|  | ||||
| func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { | ||||
| @@ -89,3 +91,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus | ||||
| 		StatusCode:  statusCode, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func shouldDisableChannel(err *OpenAIError) bool { | ||||
| 	if !common.AutomaticDisableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
							
								
								
									
										278
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										278
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,278 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://console.xfyun.cn/services/cbm | ||||
| // https://www.xfyun.cn/doc/spark/Web.html | ||||
|  | ||||
| type XunfeiMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatRequest struct { | ||||
| 	Header struct { | ||||
| 		AppId string `json:"app_id"` | ||||
| 	} `json:"header"` | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string  `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int     `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool    `json:"auditing,omitempty"` | ||||
| 		} `json:"chat"` | ||||
| 	} `json:"parameter"` | ||||
| 	Payload struct { | ||||
| 		Message struct { | ||||
| 			Text []XunfeiMessage `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponse struct { | ||||
| 	Header struct { | ||||
| 		Code    int    `json:"code"` | ||||
| 		Message string `json:"message"` | ||||
| 		Sid     string `json:"sid"` | ||||
| 		Status  int    `json:"status"` | ||||
| 	} `json:"header"` | ||||
| 	Payload struct { | ||||
| 		Choices struct { | ||||
| 			Status int                          `json:"status"` | ||||
| 			Seq    int                          `json:"seq"` | ||||
| 			Text   []XunfeiChatResponseTextItem `json:"text"` | ||||
| 		} `json:"choices"` | ||||
| 		Usage struct { | ||||
| 			//Text struct { | ||||
| 			//	QuestionTokens   string `json:"question_tokens"` | ||||
| 			//	PromptTokens     string `json:"prompt_tokens"` | ||||
| 			//	CompletionTokens string `json:"completion_tokens"` | ||||
| 			//	TotalTokens      string `json:"total_tokens"` | ||||
| 			//} `json:"text"` | ||||
| 			Text Usage `json:"text"` | ||||
| 		} `json:"usage"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest { | ||||
| 	messages := make([]XunfeiMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	xunfeiRequest := XunfeiChatRequest{} | ||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||
| 	xunfeiRequest.Parameter.Chat.Domain = "general" | ||||
| 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||
| 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
| 	return &xunfeiRequest | ||||
| } | ||||
|  | ||||
| func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Payload.Choices.Text[0].Content, | ||||
| 		}, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage:   response.Payload.Usage.Text, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "SparkDesk", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||
| 	HmacWithShaToBase64 := func(algorithm, data, key string) string { | ||||
| 		mac := hmac.New(sha256.New, []byte(key)) | ||||
| 		mac.Write([]byte(data)) | ||||
| 		encodeData := mac.Sum(nil) | ||||
| 		return base64.StdEncoding.EncodeToString(encodeData) | ||||
| 	} | ||||
| 	ul, err := url.Parse(hostUrl) | ||||
| 	if err != nil { | ||||
| 		fmt.Println(err) | ||||
| 	} | ||||
| 	date := time.Now().UTC().Format(time.RFC1123) | ||||
| 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||
| 	sign := strings.Join(signString, "\n") | ||||
| 	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) | ||||
| 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||
| 		"hmac-sha256", "host date request-line", sha) | ||||
| 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||
| 	v := url.Values{} | ||||
| 	v.Add("host", ul.Host) | ||||
| 	v.Add("date", date) | ||||
| 	v.Add("authorization", authorization) | ||||
| 	callUrl := hostUrl + "?" + v.Encode() | ||||
| 	return callUrl | ||||
| } | ||||
|  | ||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	d := websocket.Dialer{ | ||||
| 		HandshakeTimeout: 5 * time.Second, | ||||
| 	} | ||||
| 	hostUrl := "wss://aichat.xf-yun.com/v1/chat" | ||||
| 	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) | ||||
| 	if err != nil || resp.StatusCode != 101 { | ||||
| 		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	data := requestOpenAI2Xunfei(textRequest, appId) | ||||
| 	err = conn.WriteJSON(data) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	dataChan := make(chan XunfeiChatResponse) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := conn.ReadMessage() | ||||
| 			if err != nil { | ||||
| 				common.SysError("error reading stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			var response XunfeiChatResponse | ||||
| 			err = json.Unmarshal(msg, &response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			dataChan <- response | ||||
| 			if response.Payload.Choices.Status == 2 { | ||||
| 				err := conn.Close() | ||||
| 				if err != nil { | ||||
| 					common.SysError("error closing websocket connection: " + err.Error()) | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case xunfeiResponse := <-dataChan: | ||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||
| 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||
| 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var xunfeiResponse XunfeiChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &xunfeiResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if xunfeiResponse.Header.Code != 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: xunfeiResponse.Header.Message, | ||||
| 				Type:    "xunfei_error", | ||||
| 				Param:   "", | ||||
| 				Code:    xunfeiResponse.Header.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
							
								
								
									
										306
									
								
								controller/relay-zhipu.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										306
									
								
								controller/relay-zhipu.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,306 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/golang-jwt/jwt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://open.bigmodel.cn/doc/api#chatglm_std | ||||
| // chatglm_std, chatglm_lite | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke | ||||
|  | ||||
| type ZhipuMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type ZhipuRequest struct { | ||||
| 	Prompt      []ZhipuMessage `json:"prompt"` | ||||
| 	Temperature float64        `json:"temperature,omitempty"` | ||||
| 	TopP        float64        `json:"top_p,omitempty"` | ||||
| 	RequestId   string         `json:"request_id,omitempty"` | ||||
| 	Incremental bool           `json:"incremental,omitempty"` | ||||
| } | ||||
|  | ||||
| type ZhipuResponseData struct { | ||||
| 	TaskId     string         `json:"task_id"` | ||||
| 	RequestId  string         `json:"request_id"` | ||||
| 	TaskStatus string         `json:"task_status"` | ||||
| 	Choices    []ZhipuMessage `json:"choices"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ZhipuResponse struct { | ||||
| 	Code    int               `json:"code"` | ||||
| 	Msg     string            `json:"msg"` | ||||
| 	Success bool              `json:"success"` | ||||
| 	Data    ZhipuResponseData `json:"data"` | ||||
| } | ||||
|  | ||||
| type ZhipuStreamMetaResponse struct { | ||||
| 	RequestId  string `json:"request_id"` | ||||
| 	TaskId     string `json:"task_id"` | ||||
| 	TaskStatus string `json:"task_status"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
|  | ||||
| type zhipuTokenData struct { | ||||
| 	Token      string | ||||
| 	ExpiryTime time.Time | ||||
| } | ||||
|  | ||||
| var zhipuTokens sync.Map | ||||
| var expSeconds int64 = 24 * 3600 | ||||
|  | ||||
| func getZhipuToken(apikey string) string { | ||||
| 	data, ok := zhipuTokens.Load(apikey) | ||||
| 	if ok { | ||||
| 		tokenData := data.(zhipuTokenData) | ||||
| 		if time.Now().Before(tokenData.ExpiryTime) { | ||||
| 			return tokenData.Token | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	split := strings.Split(apikey, ".") | ||||
| 	if len(split) != 2 { | ||||
| 		common.SysError("invalid zhipu key: " + apikey) | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	id := split[0] | ||||
| 	secret := split[1] | ||||
|  | ||||
| 	expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 | ||||
| 	expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) | ||||
|  | ||||
| 	timestamp := time.Now().UnixNano() / 1e6 | ||||
|  | ||||
| 	payload := jwt.MapClaims{ | ||||
| 		"api_key":   id, | ||||
| 		"exp":       expMillis, | ||||
| 		"timestamp": timestamp, | ||||
| 	} | ||||
|  | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) | ||||
|  | ||||
| 	token.Header["alg"] = "HS256" | ||||
| 	token.Header["sign_type"] = "SIGN" | ||||
|  | ||||
| 	tokenString, err := token.SignedString([]byte(secret)) | ||||
| 	if err != nil { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	zhipuTokens.Store(apikey, zhipuTokenData{ | ||||
| 		Token:      tokenString, | ||||
| 		ExpiryTime: expiryTime, | ||||
| 	}) | ||||
|  | ||||
| 	return tokenString | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &ZhipuRequest{ | ||||
| 		Prompt:      messages, | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Incremental: false, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.Data.TaskId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), | ||||
| 		Usage:   response.Data.Usage, | ||||
| 	} | ||||
| 	for i, choice := range response.Data.Choices { | ||||
| 		openaiChoice := OpenAITextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 				Role:    choice.Role, | ||||
| 				Content: strings.Trim(choice.Content, "\""), | ||||
| 			}, | ||||
| 			FinishReason: "", | ||||
| 		} | ||||
| 		if i == len(response.Data.Choices)-1 { | ||||
| 			openaiChoice.FinishReason = "stop" | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = zhipuResponse | ||||
| 	choice.FinishReason = "" | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = "" | ||||
| 	choice.FinishReason = "stop" | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      zhipuResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response, &zhipuResponse.Usage | ||||
| } | ||||
|  | ||||
| func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage *Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { | ||||
| 			return i + 2, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	metaChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			lines := strings.Split(data, "\n") | ||||
| 			for i, line := range lines { | ||||
| 				if len(line) < 5 { | ||||
| 					continue | ||||
| 				} | ||||
| 				if line[:5] == "data:" { | ||||
| 					dataChan <- line[5:] | ||||
| 					if i != len(lines)-1 { | ||||
| 						dataChan <- "\n" | ||||
| 					} | ||||
| 				} else if line[:5] == "meta:" { | ||||
| 					metaChan <- line[5:] | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			response := streamResponseZhipu2OpenAI(data) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case data := <-metaChan: | ||||
| 			var zhipuResponse ZhipuStreamMetaResponse | ||||
| 			err := json.Unmarshal([]byte(data), &zhipuResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage = zhipuUsage | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
| func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var zhipuResponse ZhipuResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if !zhipuResponse.Success { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: zhipuResponse.Msg, | ||||
| 				Type:    "zhipu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    zhipuResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -2,10 +2,12 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| @@ -19,22 +21,25 @@ const ( | ||||
| 	RelayModeChatCompletions | ||||
| 	RelayModeCompletions | ||||
| 	RelayModeEmbeddings | ||||
| 	RelayModeModeration | ||||
| 	RelayModeModerations | ||||
| 	RelayModeImagesGenerations | ||||
| 	RelayModeEdits | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/chat | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Model       string    `json:"model"` | ||||
| 	Messages    []Message `json:"messages"` | ||||
| 	Prompt      any       `json:"prompt"` | ||||
| 	Stream      bool      `json:"stream"` | ||||
| 	MaxTokens   int       `json:"max_tokens"` | ||||
| 	Temperature float64   `json:"temperature"` | ||||
| 	TopP        float64   `json:"top_p"` | ||||
| 	N           int       `json:"n"` | ||||
| 	Input       any       `json:"input"` | ||||
| 	Model       string    `json:"model,omitempty"` | ||||
| 	Messages    []Message `json:"messages,omitempty"` | ||||
| 	Prompt      any       `json:"prompt,omitempty"` | ||||
| 	Stream      bool      `json:"stream,omitempty"` | ||||
| 	MaxTokens   int       `json:"max_tokens,omitempty"` | ||||
| 	Temperature float64   `json:"temperature,omitempty"` | ||||
| 	TopP        float64   `json:"top_p,omitempty"` | ||||
| 	N           int       `json:"n,omitempty"` | ||||
| 	Input       any       `json:"input,omitempty"` | ||||
| 	Instruction string    `json:"instruction,omitempty"` | ||||
| 	Size        string    `json:"size,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| @@ -51,6 +56,12 @@ type TextRequest struct { | ||||
| 	//Stream   bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| type ImageRequest struct { | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	N      int    `json:"n"` | ||||
| 	Size   string `json:"size"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	PromptTokens     int `json:"prompt_tokens"` | ||||
| 	CompletionTokens int `json:"completion_tokens"` | ||||
| @@ -70,17 +81,58 @@ type OpenAIErrorWithStatusCode struct { | ||||
| } | ||||
|  | ||||
| type TextResponse struct { | ||||
| 	Usage `json:"usage"` | ||||
| 	Error OpenAIError `json:"error"` | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| 	Error   OpenAIError `json:"error"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponseChoice struct { | ||||
| 	Index        int `json:"index"` | ||||
| 	Message      `json:"message"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponse struct { | ||||
| 	Id      string                     `json:"id"` | ||||
| 	Object  string                     `json:"object"` | ||||
| 	Created int64                      `json:"created"` | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponseItem struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Index     int       `json:"index"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponse struct { | ||||
| 	Object string                        `json:"object"` | ||||
| 	Data   []OpenAIEmbeddingResponseItem `json:"data"` | ||||
| 	Model  string                        `json:"model"` | ||||
| 	Usage  `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ImageResponse struct { | ||||
| 	Created int `json:"created"` | ||||
| 	Data    []struct { | ||||
| 		Url string `json:"url"` | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponseChoice struct { | ||||
| 	Delta struct { | ||||
| 		Content string `json:"content"` | ||||
| 	} `json:"delta"` | ||||
| 	FinishReason string `json:"finish_reason,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponse struct { | ||||
| 	Choices []struct { | ||||
| 		Delta struct { | ||||
| 			Content string `json:"content"` | ||||
| 		} `json:"delta"` | ||||
| 		FinishReason string `json:"finish_reason"` | ||||
| 	} `json:"choices"` | ||||
| 	Id      string                                `json:"id"` | ||||
| 	Object  string                                `json:"object"` | ||||
| 	Created int64                                 `json:"created"` | ||||
| 	Model   string                                `json:"model"` | ||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||
| } | ||||
|  | ||||
| type CompletionsStreamResponse struct { | ||||
| @@ -98,10 +150,14 @@ func Relay(c *gin.Context) { | ||||
| 		relayMode = RelayModeCompletions | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 		relayMode = RelayModeModeration | ||||
| 		relayMode = RelayModeModerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 		relayMode = RelayModeImagesGenerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||
| 		relayMode = RelayModeEdits | ||||
| 	} | ||||
| 	var err *OpenAIErrorWithStatusCode | ||||
| 	switch relayMode { | ||||
| @@ -111,16 +167,25 @@ func Relay(c *gin.Context) { | ||||
| 		err = relayTextHelper(c, relayMode) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		if err.StatusCode == http.StatusTooManyRequests { | ||||
| 			err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" | ||||
| 		retryTimesStr := c.Query("retry") | ||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||
| 		if retryTimesStr == "" { | ||||
| 			retryTimes = common.RetryTimes | ||||
| 		} | ||||
| 		if retryTimes > 0 { | ||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||
| 		} else { | ||||
| 			if err.StatusCode == http.StatusTooManyRequests { | ||||
| 				err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" | ||||
| 			} | ||||
| 			c.JSON(err.StatusCode, gin.H{ | ||||
| 				"error": err.OpenAIError, | ||||
| 			}) | ||||
| 		} | ||||
| 		c.JSON(err.StatusCode, gin.H{ | ||||
| 			"error": err.OpenAIError, | ||||
| 		}) | ||||
| 		channelId := c.GetInt("channel_id") | ||||
| 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 		if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") { | ||||
| 		if shouldDisableChannel(&err.OpenAIError) { | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			channelName := c.GetString("channel_name") | ||||
| 			disableChannel(channelId, channelName, err.Message) | ||||
|   | ||||
| @@ -180,10 +180,10 @@ func UpdateToken(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	if token.Status == common.TokenStatusEnabled { | ||||
| 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() { | ||||
| 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "令牌已过期,无法启用,请先修改令牌过期时间", | ||||
| 				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
| @@ -3,12 +3,13 @@ package controller | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type LoginRequest struct { | ||||
| @@ -477,6 +478,16 @@ func DeleteUser(c *gin.Context) { | ||||
|  | ||||
| func DeleteSelf(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	user, _ := model.GetUserById(id, false) | ||||
|  | ||||
| 	if user.Role == common.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "不能删除超级管理员账户", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := model.DeleteUserById(id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
							
								
								
									
										4
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
									
									
									
									
								
							| @@ -11,7 +11,9 @@ require ( | ||||
| 	github.com/gin-gonic/gin v1.9.1 | ||||
| 	github.com/go-playground/validator/v10 v10.14.0 | ||||
| 	github.com/go-redis/redis/v8 v8.11.5 | ||||
| 	github.com/golang-jwt/jwt v3.2.2+incompatible | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.1 | ||||
| 	golang.org/x/crypto v0.9.0 | ||||
| 	gorm.io/driver/mysql v1.4.3 | ||||
| @@ -20,7 +22,6 @@ require ( | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect | ||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | ||||
| 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | ||||
| @@ -32,7 +33,6 @@ require ( | ||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.6.0 // indirect | ||||
| 	github.com/goccy/go-json v0.10.2 // indirect | ||||
| 	github.com/gomodule/redigo v2.0.0+incompatible // indirect | ||||
| 	github.com/gorilla/context v1.1.1 // indirect | ||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | ||||
| 	github.com/gorilla/sessions v1.2.1 // indirect | ||||
|   | ||||
							
								
								
									
										9
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,5 +1,3 @@ | ||||
| github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04= | ||||
| github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw= | ||||
| github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | ||||
| github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | ||||
| github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||
| @@ -54,10 +52,10 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB | ||||
| github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||
| github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= | ||||
| github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||
| github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= | ||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||
| github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= | ||||
| github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= | ||||
| github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= | ||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||
| @@ -67,9 +65,10 @@ github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8 | ||||
| github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | ||||
| github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | ||||
| github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | ||||
| github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= | ||||
| github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | ||||
| github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | ||||
| github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||
| github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
|   | ||||
							
								
								
									
										121
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										121
									
								
								i18n/en.json
									
									
									
									
									
								
							| @@ -3,6 +3,11 @@ | ||||
|   "%d 点额度": "%d point quota", | ||||
|   "尚未实现": "Not yet implemented", | ||||
|   "余额不足": "Insufficient balance", | ||||
|   "危险操作": "Hazardous operations", | ||||
|   "输入你的账户名": "Enter your account name", | ||||
|   "确认删除": "Confirm Delete", | ||||
|   "确认绑定": "Confirm Binding", | ||||
|   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", | ||||
|   "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", | ||||
|   "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", | ||||
|   "测试已在运行中": "Test is already running", | ||||
| @@ -36,7 +41,7 @@ | ||||
|   "通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)", | ||||
|   "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。": "The current group load is saturated, please try again later, or upgrade your account to improve service quality.", | ||||
|   "令牌名称长度必须在1-20之间": "The length of the token name must be between 1-20", | ||||
|   "令牌已过期,无法启用,请先修改令牌过期时间": "The token has expired and cannot be enabled. Please modify the token expiration time first", | ||||
|   "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.", | ||||
|   "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota", | ||||
|   "管理员关闭了密码登录": "The administrator has turned off password login", | ||||
|   "无法保存会话信息,请重试": "Unable to save session information, please try again", | ||||
| @@ -107,6 +112,11 @@ | ||||
|   "已禁用": "Disabled", | ||||
|   "未知状态": "Unknown status", | ||||
|   " 秒": "s", | ||||
|   " 分钟 ": " m ", | ||||
|   " 小时 ": " h ", | ||||
|   " 天 ": " d ", | ||||
|   " 个月 ": " M ", | ||||
|   " 年 ": " y ", | ||||
|   "未测试": "Not tested", | ||||
|   "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", | ||||
|   "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", | ||||
| @@ -366,12 +376,12 @@ | ||||
|   "添加新的用户": "Add New User", | ||||
|   "自定义": "Custom", | ||||
|   "等价金额": "Equivalent Amount", | ||||
|   "错误": "Error", | ||||
|   "错误:未登录或登录已过期,请重新登录": "Error: Not logged in or login has expired, please log in again", | ||||
|   "错误:请求次数过多,请稍后再试": "Error: Too many requests, please try again later", | ||||
|   "错误:服务器内部错误,请联系管理员": "Error: Server internal error, please contact the administrator", | ||||
|   "未登录或登录已过期,请重新登录": "Not logged in or login has expired, please log in again", | ||||
|   "请求次数过多,请稍后再试": "Too many requests, please try again later", | ||||
|   "服务器内部错误,请联系管理员": "Server internal error, please contact the administrator", | ||||
|   "本站仅作演示之用,无服务端": "This site is for demonstration purposes only, no server-side", | ||||
|   "错误:": "Error:", | ||||
|   "超级管理员未设置充值链接!": "Super administrator has not set the recharge link!", | ||||
|   "错误:": "Error: ", | ||||
|   "新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using shortcut Shift + F5", | ||||
|   "无法正常连接至服务器": "Unable to connect to the server normally", | ||||
|   "管理渠道": "Manage Channels", | ||||
| @@ -384,7 +394,7 @@ | ||||
|   "系统配置": "System Configuration", | ||||
|   "系统配置总览": "System Configuration Overview", | ||||
|   "邮箱验证": "Email Verification", | ||||
|   "未": "Not ", | ||||
|   "未启用": "Not Enabled", | ||||
|   "GitHub 身份验证": "GitHub Authentication", | ||||
|   "微信身份验证": "WeChat Authentication", | ||||
|   "Turnstile 用户校验": "Turnstile User Verification", | ||||
| @@ -411,9 +421,100 @@ | ||||
|   "其他设置": "Other Settings", | ||||
|   "项目仓库地址": "Project Repository Address", | ||||
|   "可在设置页面设置关于内容,支持 HTML & Markdown": "You can set the content about in the settings page, support HTML & Markdown", | ||||
|   "由{' '}": "build by{' '}", | ||||
|   "由{' '}": "built by{' '}", | ||||
|   "构建,源代码遵循{' '}": ", the source code licensed under{' '}", | ||||
|   "MIT 协议": "MIT License", | ||||
|   "充值额度": "Recharge Quota", | ||||
|   "获取兑换码": "Get Redeem Code" | ||||
| } | ||||
|   "获取兑换码": "Get Redeem Code", | ||||
|   "一个月后过期": "Expires after one month", | ||||
|   "一天后过期": "Expires after one day", | ||||
|   "一小时后过期": "Expires after one hour", | ||||
|   "一分钟后过期": "Expires after one minute", | ||||
|   "创建新的令牌": "Create New Token", | ||||
|   "注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.", | ||||
|   "设为无限额度": "Set to unlimited quota", | ||||
|   "更新令牌信息": "Update Token Information", | ||||
|   "请输入充值码!": "Please enter the recharge code!", | ||||
|   "请输入名称": "Please enter a name", | ||||
|   "请输入密钥,一行一个": "Please enter the key, one per line", | ||||
|   "请输入额度": "Please enter the quota", | ||||
|   "令牌创建成功": "Token created successfully", | ||||
|   "令牌更新成功": "Token updated successfully", | ||||
|   "充值成功!": "Recharge successful!", | ||||
|   "更新用户信息": "Update User Information", | ||||
|   "请输入新的用户名": "Please enter a new username", | ||||
|   "密码": "Password", | ||||
|   "请输入新的密码": "Please enter a new password", | ||||
|   "显示名称": "Display Name", | ||||
|   "请输入新的显示名称": "Please enter a new display name", | ||||
|   "已绑定的 GitHub 账户": "GitHub Account Bound", | ||||
|   "此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only. Users need to bind through the relevant binding button on the personal settings page, and cannot be modified directly", | ||||
|   "已绑定的微信账户": "WeChat Account Bound", | ||||
|   "已绑定的邮箱账户": "Email Account Bound", | ||||
|   "用户信息更新成功!": "User information updated successfully!", | ||||
|   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", | ||||
|   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", | ||||
|   "用户名称": "User Name", | ||||
|   "令牌名称": "Token Name", | ||||
|   "留空则查询全部用户": "Leave blank to query all users", | ||||
|   "留空则查询全部令牌": "Leave blank to query all tokens", | ||||
|   "模型名称": "Model Name", | ||||
|   "留空则查询全部模型": "Leave blank to query all models", | ||||
|   "起始时间": "Start Time", | ||||
|   "结束时间": "End Time", | ||||
|   "查询": "Query", | ||||
|   "提示": "Prompt", | ||||
|   "补全": "Completion", | ||||
|   "消耗额度": "Used Quota", | ||||
|   "可选值": "Optional Values", | ||||
|   "渠道不存在:%d": "Channel does not exist: %d", | ||||
|   "数据库一致性已被破坏,请联系管理员": "Database consistency has been broken, please contact the administrator", | ||||
|   "使用近似的方式估算 token 数以减少计算量": "Estimate the number of tokens in an approximate way to reduce computational load", | ||||
|   "请填写ChannelName和ChannelKey!": "Please fill in the ChannelName and ChannelKey!", | ||||
|   "请至少选择一个Model!": "Please select at least one Model!", | ||||
|   "加载首页内容失败": "Failed to load the homepage content", | ||||
|   "加载关于内容失败": "Failed to load the About content", | ||||
|   "兑换码更新成功!": "Redemption code updated successfully!", | ||||
|   "兑换码创建成功!": "Redemption code created successfully!", | ||||
|   "用户账户创建成功!": "User account created successfully!", | ||||
|   "生成数量": "Generate quantity", | ||||
|   "请输入生成数量": "Please enter the quantity to generate", | ||||
|   "创建新用户账户": "Create new user account", | ||||
|   "渠道更新成功!": "Channel updated successfully!", | ||||
|   "渠道创建成功!": "Channel created successfully!", | ||||
|   "请选择分组": "Please select a group", | ||||
|   "更新兑换码信息": "Update redemption code information", | ||||
|   "创建新的兑换码": "Create a new redemption code", | ||||
|   "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group ratio in the system settings page to add a new group:", | ||||
|   "未找到所请求的页面": "The requested page was not found", | ||||
|   "过期时间格式错误!": "Expiration time format error!", | ||||
|   "请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means no limit", | ||||
|   "此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:": "This is optional, it's a JSON text, the key is the model name requested by the user, and the value is the model name to be replaced, for example:", | ||||
|   "此项可选,输入镜像站地址,格式为:": "This is optional, enter the mirror site address, the format is:", | ||||
|   "模型映射": "Model mapping", | ||||
|   "请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖": "Please enter the default API version, for example: 2023-03-15-preview, this configuration can be overridden by the actual request query parameters", | ||||
|   "默认": "Default", | ||||
|   "图片演示": "Image demo", | ||||
|   "参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)", | ||||
|   "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", | ||||
|   "取消无限额度": "Cancel unlimited quota", | ||||
|   "取消": "Cancel", | ||||
|   "请输入新的剩余额度": "Please enter the new remaining quota", | ||||
|   "请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code", | ||||
|   "请输入用户名": "Please enter username", | ||||
|   "请输入显示名称": "Please enter display name", | ||||
|   "请输入密码": "Please enter password", | ||||
|   "模型部署名称必须和模型名称保持一致": "The model deployment name must be consistent with the model name", | ||||
|   ",因为 One API 会把请求体中的 model": ", because One API will take the model in the request body", | ||||
|   "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", | ||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||
|   "Homepage URL 填": "Fill in the Homepage URL", | ||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||
|   "请为通道命名": "Please name the channel", | ||||
|   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||
|   "模型重定向": "Model redirection", | ||||
|   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||
|   "注意,": "Note that, ", | ||||
|   ",图片演示。": "related image demo.", | ||||
|   "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!" | ||||
| } | ||||
|   | ||||
							
								
								
									
										12
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								main.go
									
									
									
									
									
								
							| @@ -4,7 +4,6 @@ import ( | ||||
| 	"embed" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-contrib/sessions/cookie" | ||||
| 	"github.com/gin-contrib/sessions/redis" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/controller" | ||||
| @@ -55,6 +54,7 @@ func main() { | ||||
| 		if err != nil { | ||||
| 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) | ||||
| 		} | ||||
| 		common.SyncFrequency = frequency | ||||
| 		go model.SyncOptions(frequency) | ||||
| 		if common.RedisEnabled { | ||||
| 			go model.SyncChannelCache(frequency) | ||||
| @@ -82,14 +82,8 @@ func main() { | ||||
| 	server.Use(middleware.CORS()) | ||||
|  | ||||
| 	// Initialize session store | ||||
| 	if common.RedisEnabled { | ||||
| 		opt := common.ParseRedisOption() | ||||
| 		store, _ := redis.NewStore(opt.MinIdleConns, opt.Network, opt.Addr, opt.Password, []byte(common.SessionSecret)) | ||||
| 		server.Use(sessions.Sessions("session", store)) | ||||
| 	} else { | ||||
| 		store := cookie.NewStore([]byte(common.SessionSecret)) | ||||
| 		server.Use(sessions.Sessions("session", store)) | ||||
| 	} | ||||
| 	store := cookie.NewStore([]byte(common.SessionSecret)) | ||||
| 	server.Use(sessions.Sessions("session", store)) | ||||
|  | ||||
| 	router.SetRouter(server, buildFS, indexPage) | ||||
| 	var port = os.Getenv("PORT") | ||||
|   | ||||
| @@ -2,12 +2,13 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| @@ -73,11 +74,26 @@ func Distribute() func(c *gin.Context) { | ||||
| 					modelRequest.Model = "text-moderation-stable" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = c.Param("model") | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e" | ||||
| 				} | ||||
| 			} | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				if channel != nil { | ||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					message = "数据库一致性已被破坏,请联系管理员" | ||||
| 				} | ||||
| 				c.JSON(http.StatusServiceUnavailable, gin.H{ | ||||
| 					"error": gin.H{ | ||||
| 						"message": "无可用渠道", | ||||
| 						"message": message, | ||||
| 						"type":    "one_api_error", | ||||
| 					}, | ||||
| 				}) | ||||
| @@ -88,6 +104,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 		c.Set("channel", channel.Type) | ||||
| 		c.Set("channel_id", channel.Id) | ||||
| 		c.Set("channel_name", channel.Name) | ||||
| 		c.Set("model_mapping", channel.ModelMapping) | ||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
| 		c.Set("base_url", channel.BaseURL) | ||||
| 		if channel.Type == common.ChannelTypeAzure { | ||||
|   | ||||
| @@ -24,6 +24,7 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	channel := Channel{} | ||||
| 	channel.Id = ability.ChannelId | ||||
| 	err = DB.First(&channel, "id = ?", ability.ChannelId).Error | ||||
| 	return &channel, err | ||||
| } | ||||
|   | ||||
| @@ -12,11 +12,11 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	TokenCacheSeconds         = 60 * 60 | ||||
| 	UserId2GroupCacheSeconds  = 60 * 60 | ||||
| 	UserId2QuotaCacheSeconds  = 10 * 60 | ||||
| 	UserId2StatusCacheSeconds = 60 * 60 | ||||
| var ( | ||||
| 	TokenCacheSeconds         = common.SyncFrequency | ||||
| 	UserId2GroupCacheSeconds  = common.SyncFrequency | ||||
| 	UserId2QuotaCacheSeconds  = common.SyncFrequency | ||||
| 	UserId2StatusCacheSeconds = common.SyncFrequency | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| @@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set token error: " + err.Error()) | ||||
| 		} | ||||
| @@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) { | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user group error: " + err.Error()) | ||||
| 		} | ||||
| @@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user quota error: " + err.Error()) | ||||
| 		} | ||||
| @@ -83,6 +83,18 @@ func CacheGetUserQuota(id int) (quota int, err error) { | ||||
| 	return quota, err | ||||
| } | ||||
|  | ||||
| func CacheUpdateUserQuota(id int) error { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return nil | ||||
| 	} | ||||
| 	quota, err := GetUserQuota(id) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func CacheIsUserEnabled(userId int) bool { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return IsUserEnabled(userId) | ||||
| @@ -94,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool { | ||||
| 			status = common.UserStatusEnabled | ||||
| 		} | ||||
| 		enabled = fmt.Sprintf("%d", status) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user enabled error: " + err.Error()) | ||||
| 		} | ||||
| @@ -108,7 +120,7 @@ var channelSyncLock sync.RWMutex | ||||
| func InitChannelCache() { | ||||
| 	newChannelId2channel := make(map[int]*Channel) | ||||
| 	var channels []*Channel | ||||
| 	DB.Find(&channels) | ||||
| 	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) | ||||
| 	for _, channel := range channels { | ||||
| 		newChannelId2channel[channel.Id] = channel | ||||
| 	} | ||||
|   | ||||
| @@ -22,6 +22,7 @@ type Channel struct { | ||||
| 	Models             string  `json:"models"` | ||||
| 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||
| 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| @@ -36,7 +37,7 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| } | ||||
|  | ||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or key = ?", keyword, keyword+"%", keyword).Find(&channels).Error | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										112
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										112
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -6,11 +6,17 @@ import ( | ||||
| ) | ||||
|  | ||||
| type Log struct { | ||||
| 	Id        int    `json:"id"` | ||||
| 	UserId    int    `json:"user_id" gorm:"index"` | ||||
| 	CreatedAt int64  `json:"created_at" gorm:"bigint"` | ||||
| 	Type      int    `json:"type" gorm:"index"` | ||||
| 	Content   string `json:"content"` | ||||
| 	Id               int    `json:"id"` | ||||
| 	UserId           int    `json:"user_id"` | ||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index"` | ||||
| 	Type             int    `json:"type" gorm:"index"` | ||||
| 	Content          string `json:"content"` | ||||
| 	Username         string `json:"username" gorm:"index;default:''"` | ||||
| 	TokenName        string `json:"token_name" gorm:"index;default:''"` | ||||
| 	ModelName        string `json:"model_name" gorm:"index;default:''"` | ||||
| 	Quota            int    `json:"quota" gorm:"default:0"` | ||||
| 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | ||||
| 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | ||||
| } | ||||
|  | ||||
| const ( | ||||
| @@ -27,6 +33,7 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 	} | ||||
| 	log := &Log{ | ||||
| 		UserId:    userId, | ||||
| 		Username:  GetUsernameById(userId), | ||||
| 		CreatedAt: common.GetTimestamp(), | ||||
| 		Type:      logType, | ||||
| 		Content:   content, | ||||
| @@ -37,24 +44,73 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetAllLogs(logType int, startIdx int, num int) (logs []*Log, err error) { | ||||
| func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||
| 	if !common.LogConsumeEnabled { | ||||
| 		return | ||||
| 	} | ||||
| 	log := &Log{ | ||||
| 		UserId:           userId, | ||||
| 		Username:         GetUsernameById(userId), | ||||
| 		CreatedAt:        common.GetTimestamp(), | ||||
| 		Type:             LogTypeConsume, | ||||
| 		Content:          content, | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TokenName:        tokenName, | ||||
| 		ModelName:        modelName, | ||||
| 		Quota:            quota, | ||||
| 	} | ||||
| 	err := DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | ||||
| 	var tx *gorm.DB | ||||
| 	if logType == LogTypeUnknown { | ||||
| 		tx = DB | ||||
| 	} else { | ||||
| 		tx = DB.Where("type = ?", logType) | ||||
| 	} | ||||
| 	if modelName != "" { | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| 	} | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| 	if tokenName != "" { | ||||
| 		tx = tx.Where("token_name = ?", tokenName) | ||||
| 	} | ||||
| 	if startTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at >= ?", startTimestamp) | ||||
| 	} | ||||
| 	if endTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||
| 	} | ||||
| 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | ||||
| 	return logs, err | ||||
| } | ||||
|  | ||||
| func GetUserLogs(userId int, logType int, startIdx int, num int) (logs []*Log, err error) { | ||||
| func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | ||||
| 	var tx *gorm.DB | ||||
| 	if logType == LogTypeUnknown { | ||||
| 		tx = DB.Where("user_id = ?", userId) | ||||
| 	} else { | ||||
| 		tx = DB.Where("user_id = ? and type = ?", userId, logType) | ||||
| 	} | ||||
| 	if modelName != "" { | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| 	} | ||||
| 	if tokenName != "" { | ||||
| 		tx = tx.Where("token_name = ?", tokenName) | ||||
| 	} | ||||
| 	if startTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at >= ?", startTimestamp) | ||||
| 	} | ||||
| 	if endTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||
| 	} | ||||
| 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error | ||||
| 	return logs, err | ||||
| } | ||||
| @@ -68,3 +124,45 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error | ||||
| 	return logs, err | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) { | ||||
| 	tx := DB.Table("logs").Select("sum(quota)") | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| 	if tokenName != "" { | ||||
| 		tx = tx.Where("token_name = ?", tokenName) | ||||
| 	} | ||||
| 	if startTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at >= ?", startTimestamp) | ||||
| 	} | ||||
| 	if endTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||
| 	} | ||||
| 	if modelName != "" { | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| 	} | ||||
| 	tx.Where("type = ?", LogTypeConsume).Scan("a) | ||||
| 	return quota | ||||
| } | ||||
|  | ||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||
| 	tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)") | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| 	if tokenName != "" { | ||||
| 		tx = tx.Where("token_name = ?", tokenName) | ||||
| 	} | ||||
| 	if startTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at >= ?", startTimestamp) | ||||
| 	} | ||||
| 	if endTimestamp != 0 { | ||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||
| 	} | ||||
| 	if modelName != "" { | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| 	} | ||||
| 	tx.Where("type = ?", LogTypeConsume).Scan(&token) | ||||
| 	return token | ||||
| } | ||||
|   | ||||
| @@ -34,10 +34,13 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) | ||||
| 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) | ||||
| 	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) | ||||
| 	common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) | ||||
| 	common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) | ||||
| 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | ||||
| 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) | ||||
| 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) | ||||
| 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) | ||||
| 	common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") | ||||
| 	common.OptionMap["SMTPServer"] = "" | ||||
| 	common.OptionMap["SMTPFrom"] = "" | ||||
| 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) | ||||
| @@ -67,6 +70,7 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["TopUpLink"] = common.TopUpLink | ||||
| 	common.OptionMap["ChatLink"] = common.ChatLink | ||||
| 	common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) | ||||
| 	common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) | ||||
| 	common.OptionMapRWMutex.Unlock() | ||||
| 	loadOptionsFromDatabase() | ||||
| } | ||||
| @@ -139,8 +143,12 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 			common.TurnstileCheckEnabled = boolValue | ||||
| 		case "RegisterEnabled": | ||||
| 			common.RegisterEnabled = boolValue | ||||
| 		case "EmailDomainRestrictionEnabled": | ||||
| 			common.EmailDomainRestrictionEnabled = boolValue | ||||
| 		case "AutomaticDisableChannelEnabled": | ||||
| 			common.AutomaticDisableChannelEnabled = boolValue | ||||
| 		case "ApproximateTokenEnabled": | ||||
| 			common.ApproximateTokenEnabled = boolValue | ||||
| 		case "LogConsumeEnabled": | ||||
| 			common.LogConsumeEnabled = boolValue | ||||
| 		case "DisplayInCurrencyEnabled": | ||||
| @@ -150,6 +158,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		} | ||||
| 	} | ||||
| 	switch key { | ||||
| 	case "EmailDomainWhitelist": | ||||
| 		common.EmailDomainWhitelist = strings.Split(value, ",") | ||||
| 	case "SMTPServer": | ||||
| 		common.SMTPServer = value | ||||
| 	case "SMTPPort": | ||||
| @@ -193,6 +203,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		common.QuotaRemindThreshold, _ = strconv.Atoi(value) | ||||
| 	case "PreConsumedQuota": | ||||
| 		common.PreConsumedQuota, _ = strconv.Atoi(value) | ||||
| 	case "RetryTimes": | ||||
| 		common.RetryTimes, _ = strconv.Atoi(value) | ||||
| 	case "ModelRatio": | ||||
| 		err = common.UpdateModelRatioByJSONString(value) | ||||
| 	case "GroupRatio": | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package model | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| ) | ||||
|  | ||||
| @@ -48,26 +49,28 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 		return 0, errors.New("无效的 user id") | ||||
| 	} | ||||
| 	redemption := &Redemption{} | ||||
| 	err = DB.Where("`key` = ?", key).First(redemption).Error | ||||
| 	if err != nil { | ||||
| 		return 0, errors.New("无效的兑换码") | ||||
| 	} | ||||
| 	if redemption.Status != common.RedemptionCodeStatusEnabled { | ||||
| 		return 0, errors.New("该兑换码已被使用") | ||||
| 	} | ||||
| 	err = IncreaseUserQuota(userId, redemption.Quota) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	go func() { | ||||
|  | ||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | ||||
| 		if err != nil { | ||||
| 			return errors.New("无效的兑换码") | ||||
| 		} | ||||
| 		if redemption.Status != common.RedemptionCodeStatusEnabled { | ||||
| 			return errors.New("该兑换码已被使用") | ||||
| 		} | ||||
| 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		redemption.RedeemedTime = common.GetTimestamp() | ||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | ||||
| 		err := redemption.SelectUpdate() | ||||
| 		if err != nil { | ||||
| 			common.SysError("failed to update redemption status: " + err.Error()) | ||||
| 		} | ||||
| 		RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) | ||||
| 	}() | ||||
| 		err = tx.Save(redemption).Error | ||||
| 		return err | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return 0, errors.New("兑换失败," + err.Error()) | ||||
| 	} | ||||
| 	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) | ||||
| 	return redemption.Quota, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -303,3 +303,8 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| 		common.SysError("failed to update user used quota and request count: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUsernameById(id int) (username string) { | ||||
| 	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) | ||||
| 	return username | ||||
| } | ||||
|   | ||||
| @@ -1,10 +1,11 @@ | ||||
| package router | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-contrib/gzip" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/controller" | ||||
| 	"one-api/middleware" | ||||
|  | ||||
| 	"github.com/gin-contrib/gzip" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func SetApiRouter(router *gin.Engine) { | ||||
| @@ -35,7 +36,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 			{ | ||||
| 				selfRoute.GET("/self", controller.GetSelf) | ||||
| 				selfRoute.PUT("/self", controller.UpdateSelf) | ||||
| 				selfRoute.DELETE("/self", controller.DeleteSelf) | ||||
| 				selfRoute.DELETE("/self", middleware.TurnstileCheck(), controller.DeleteSelf) | ||||
| 				selfRoute.GET("/token", controller.GenerateAccessToken) | ||||
| 				selfRoute.GET("/aff", controller.GetAffCode) | ||||
| 				selfRoute.POST("/topup", controller.TopUp) | ||||
| @@ -96,6 +97,8 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 		} | ||||
| 		logRoute := apiRouter.Group("/log") | ||||
| 		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) | ||||
| 		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) | ||||
| 		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) | ||||
| 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) | ||||
| 		logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) | ||||
| 		logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| ) | ||||
| @@ -14,6 +15,10 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { | ||||
| 	SetDashboardRouter(router) | ||||
| 	SetRelayRouter(router) | ||||
| 	frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") | ||||
| 	if common.IsMasterNode && frontendBaseUrl != "" { | ||||
| 		frontendBaseUrl = "" | ||||
| 		common.SysLog("FRONTEND_BASE_URL is ignored on master node") | ||||
| 	} | ||||
| 	if frontendBaseUrl == "" { | ||||
| 		SetWebRouter(router, buildFS, indexPage) | ||||
| 	} else { | ||||
|   | ||||
| @@ -1,9 +1,10 @@ | ||||
| package router | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/controller" | ||||
| 	"one-api/middleware" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func SetRelayRouter(router *gin.Engine) { | ||||
| @@ -11,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) { | ||||
| 	modelsRouter := router.Group("/v1/models") | ||||
| 	modelsRouter.Use(middleware.TokenAuth()) | ||||
| 	{ | ||||
| 		modelsRouter.GET("/", controller.ListModels) | ||||
| 		modelsRouter.GET("", controller.ListModels) | ||||
| 		modelsRouter.GET("/:model", controller.RetrieveModel) | ||||
| 	} | ||||
| 	relayV1Router := router.Group("/v1") | ||||
| @@ -19,11 +20,12 @@ func SetRelayRouter(router *gin.Engine) { | ||||
| 	{ | ||||
| 		relayV1Router.POST("/completions", controller.Relay) | ||||
| 		relayV1Router.POST("/chat/completions", controller.Relay) | ||||
| 		relayV1Router.POST("/edits", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/images/generations", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/edits", controller.Relay) | ||||
| 		relayV1Router.POST("/images/generations", controller.Relay) | ||||
| 		relayV1Router.POST("/images/edits", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/embeddings", controller.Relay) | ||||
| 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/files", controller.RelayNotImplemented) | ||||
|   | ||||
| @@ -30,6 +30,9 @@ function renderType(type) { | ||||
| function renderBalance(type, balance) { | ||||
|   switch (type) { | ||||
|     case 1: // OpenAI | ||||
|       return <span>${balance.toFixed(2)}</span>; | ||||
|     case 4: // CloseAI | ||||
|       return <span>¥{balance.toFixed(2)}</span>; | ||||
|     case 8: // 自定义 | ||||
|       return <span>${balance.toFixed(2)}</span>; | ||||
|     case 5: // OpenAI-SB | ||||
| @@ -38,6 +41,8 @@ function renderBalance(type, balance) { | ||||
|       return <span>{renderNumber(balance)}</span>; | ||||
|     case 12: // API2GPT | ||||
|       return <span>¥{balance.toFixed(2)}</span>; | ||||
|     case 13: // AIGC2D | ||||
|       return <span>{renderNumber(balance)}</span>; | ||||
|     default: | ||||
|       return <span>不支持</span>; | ||||
|   } | ||||
| @@ -58,8 +63,8 @@ const ChannelsTable = () => { | ||||
|       if (startIdx === 0) { | ||||
|         setChannels(data); | ||||
|       } else { | ||||
|         let newChannels = channels; | ||||
|         newChannels.push(...data); | ||||
|         let newChannels = [...channels]; | ||||
|         newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); | ||||
|         setChannels(newChannels); | ||||
|       } | ||||
|     } else { | ||||
| @@ -80,7 +85,7 @@ const ChannelsTable = () => { | ||||
|  | ||||
|   const refresh = async () => { | ||||
|     setLoading(true); | ||||
|     await loadChannels(0); | ||||
|     await loadChannels(activePage - 1); | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
| @@ -238,7 +243,7 @@ const ChannelsTable = () => { | ||||
|     if (channels.length === 0) return; | ||||
|     setLoading(true); | ||||
|     let sortedChannels = [...channels]; | ||||
|     if (typeof sortedChannels[0][key] === 'string'){ | ||||
|     if (typeof sortedChannels[0][key] === 'string') { | ||||
|       sortedChannels.sort((a, b) => { | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       }); | ||||
| @@ -358,9 +363,12 @@ const ChannelsTable = () => { | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <Popup | ||||
|                       content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'} | ||||
|                       key={channel.id} | ||||
|                       trigger={renderBalance(channel.type, channel.balance)} | ||||
|                       trigger={<span onClick={() => { | ||||
|                         updateChannelBalance(channel.id, channel.name, idx); | ||||
|                       }} style={{ cursor: 'pointer' }}> | ||||
|                       {renderBalance(channel.type, channel.balance)} | ||||
|                     </span>} | ||||
|                       content="点击更新" | ||||
|                       basic | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
| @@ -375,16 +383,16 @@ const ChannelsTable = () => { | ||||
|                       > | ||||
|                         测试 | ||||
|                       </Button> | ||||
|                       <Button | ||||
|                         size={'small'} | ||||
|                         positive | ||||
|                         loading={updatingBalance} | ||||
|                         onClick={() => { | ||||
|                           updateChannelBalance(channel.id, channel.name, idx); | ||||
|                         }} | ||||
|                       > | ||||
|                         更新余额 | ||||
|                       </Button> | ||||
|                       {/*<Button*/} | ||||
|                       {/*  size={'small'}*/} | ||||
|                       {/*  positive*/} | ||||
|                       {/*  loading={updatingBalance}*/} | ||||
|                       {/*  onClick={() => {*/} | ||||
|                       {/*    updateChannelBalance(channel.id, channel.name, idx);*/} | ||||
|                       {/*  }}*/} | ||||
|                       {/*>*/} | ||||
|                       {/*  更新余额*/} | ||||
|                       {/*</Button>*/} | ||||
|                       <Popup | ||||
|                         trigger={ | ||||
|                           <Button size='small' negative> | ||||
|   | ||||
| @@ -1,15 +1,5 @@ | ||||
| import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { | ||||
|   Button, | ||||
|   Divider, | ||||
|   Form, | ||||
|   Grid, | ||||
|   Header, | ||||
|   Image, | ||||
|   Message, | ||||
|   Modal, | ||||
|   Segment, | ||||
| } from 'semantic-ui-react'; | ||||
| import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { UserContext } from '../context/User'; | ||||
| import { API, getLogo, showError, showSuccess } from '../helpers'; | ||||
| @@ -18,19 +8,18 @@ const LoginForm = () => { | ||||
|   const [inputs, setInputs] = useState({ | ||||
|     username: '', | ||||
|     password: '', | ||||
|     wechat_verification_code: '', | ||||
|     wechat_verification_code: '' | ||||
|   }); | ||||
|   const [searchParams, setSearchParams] = useSearchParams(); | ||||
|   const [submitted, setSubmitted] = useState(false); | ||||
|   const { username, password } = inputs; | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const [status, setStatus] = useState({}); | ||||
|   const logo = getLogo(); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     if (searchParams.get("expired")) { | ||||
|     if (searchParams.get('expired')) { | ||||
|       showError('未登录或登录已过期,请重新登录!'); | ||||
|     } | ||||
|     let status = localStorage.getItem('status'); | ||||
| @@ -76,9 +65,9 @@ const LoginForm = () => { | ||||
|   async function handleSubmit(e) { | ||||
|     setSubmitted(true); | ||||
|     if (username && password) { | ||||
|       const res = await API.post('/api/user/login', { | ||||
|       const res = await API.post(`/api/user/login`, { | ||||
|         username, | ||||
|         password, | ||||
|         password | ||||
|       }); | ||||
|       const { success, message, data } = res.data; | ||||
|       if (success) { | ||||
| @@ -93,44 +82,44 @@ const LoginForm = () => { | ||||
|   } | ||||
|  | ||||
|   return ( | ||||
|     <Grid textAlign="center" style={{ marginTop: '48px' }}> | ||||
|     <Grid textAlign='center' style={{ marginTop: '48px' }}> | ||||
|       <Grid.Column style={{ maxWidth: 450 }}> | ||||
|         <Header as="h2" color="" textAlign="center"> | ||||
|         <Header as='h2' color='' textAlign='center'> | ||||
|           <Image src={logo} /> 用户登录 | ||||
|         </Header> | ||||
|         <Form size="large"> | ||||
|         <Form size='large'> | ||||
|           <Segment> | ||||
|             <Form.Input | ||||
|               fluid | ||||
|               icon="user" | ||||
|               iconPosition="left" | ||||
|               placeholder="用户名" | ||||
|               name="username" | ||||
|               icon='user' | ||||
|               iconPosition='left' | ||||
|               placeholder='用户名' | ||||
|               name='username' | ||||
|               value={username} | ||||
|               onChange={handleChange} | ||||
|             /> | ||||
|             <Form.Input | ||||
|               fluid | ||||
|               icon="lock" | ||||
|               iconPosition="left" | ||||
|               placeholder="密码" | ||||
|               name="password" | ||||
|               type="password" | ||||
|               icon='lock' | ||||
|               iconPosition='left' | ||||
|               placeholder='密码' | ||||
|               name='password' | ||||
|               type='password' | ||||
|               value={password} | ||||
|               onChange={handleChange} | ||||
|             /> | ||||
|             <Button color="" fluid size="large" onClick={handleSubmit}> | ||||
|             <Button color='green' fluid size='large' onClick={handleSubmit}> | ||||
|               登录 | ||||
|             </Button> | ||||
|           </Segment> | ||||
|         </Form> | ||||
|         <Message> | ||||
|           忘记密码? | ||||
|           <Link to="/reset" className="btn btn-link"> | ||||
|           <Link to='/reset' className='btn btn-link'> | ||||
|             点击重置 | ||||
|           </Link> | ||||
|           ; 没有账户? | ||||
|           <Link to="/register" className="btn btn-link"> | ||||
|           <Link to='/register' className='btn btn-link'> | ||||
|             点击注册 | ||||
|           </Link> | ||||
|         </Message> | ||||
| @@ -140,8 +129,8 @@ const LoginForm = () => { | ||||
|             {status.github_oauth ? ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color="black" | ||||
|                 icon="github" | ||||
|                 color='black' | ||||
|                 icon='github' | ||||
|                 onClick={onGitHubOAuthClicked} | ||||
|               /> | ||||
|             ) : ( | ||||
| @@ -150,8 +139,8 @@ const LoginForm = () => { | ||||
|             {status.wechat_login ? ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color="green" | ||||
|                 icon="wechat" | ||||
|                 color='green' | ||||
|                 icon='wechat' | ||||
|                 onClick={onWeChatLoginClicked} | ||||
|               /> | ||||
|             ) : ( | ||||
| @@ -175,18 +164,18 @@ const LoginForm = () => { | ||||
|                   微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) | ||||
|                 </p> | ||||
|               </div> | ||||
|               <Form size="large"> | ||||
|               <Form size='large'> | ||||
|                 <Form.Input | ||||
|                   fluid | ||||
|                   placeholder="验证码" | ||||
|                   name="wechat_verification_code" | ||||
|                   placeholder='验证码' | ||||
|                   name='wechat_verification_code' | ||||
|                   value={inputs.wechat_verification_code} | ||||
|                   onChange={handleChange} | ||||
|                 /> | ||||
|                 <Button | ||||
|                   color="" | ||||
|                   color='' | ||||
|                   fluid | ||||
|                   size="large" | ||||
|                   size='large' | ||||
|                   onClick={onSubmitWeChatVerificationCode} | ||||
|                 > | ||||
|                   登录 | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Label, Pagination, Select, Table } from 'semantic-ui-react'; | ||||
| import { Button, Form, Header, Label, Pagination, Segment, Select, Table } from 'semantic-ui-react'; | ||||
| import { API, isAdmin, showError, timestamp2string } from '../helpers'; | ||||
|  | ||||
| import { ITEMS_PER_PAGE } from '../constants'; | ||||
| import { renderQuota } from '../helpers/render'; | ||||
|  | ||||
| function renderTimestamp(timestamp) { | ||||
|   return ( | ||||
| @@ -14,7 +15,7 @@ function renderTimestamp(timestamp) { | ||||
|  | ||||
| const MODE_OPTIONS = [ | ||||
|   { key: 'all', text: '全部用户', value: 'all' }, | ||||
|   { key: 'self', text: '当前用户', value: 'self' }, | ||||
|   { key: 'self', text: '当前用户', value: 'self' } | ||||
| ]; | ||||
|  | ||||
| const LOG_OPTIONS = [ | ||||
| @@ -47,13 +48,58 @@ const LogsTable = () => { | ||||
|   const [searchKeyword, setSearchKeyword] = useState(''); | ||||
|   const [searching, setSearching] = useState(false); | ||||
|   const [logType, setLogType] = useState(0); | ||||
|   const [mode, setMode] = useState('self'); // all, self | ||||
|   const showModePanel = isAdmin(); | ||||
|   const isAdminUser = isAdmin(); | ||||
|   let now = new Date(); | ||||
|   const [inputs, setInputs] = useState({ | ||||
|     username: '', | ||||
|     token_name: '', | ||||
|     model_name: '', | ||||
|     start_timestamp: timestamp2string(0), | ||||
|     end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) | ||||
|   }); | ||||
|   const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; | ||||
|  | ||||
|   const [stat, setStat] = useState({ | ||||
|     quota: 0, | ||||
|     token: 0 | ||||
|   }); | ||||
|  | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|   }; | ||||
|  | ||||
|   const getLogSelfStat = async () => { | ||||
|     let localStartTimestamp = Date.parse(start_timestamp) / 1000; | ||||
|     let localEndTimestamp = Date.parse(end_timestamp) / 1000; | ||||
|     let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       setStat(data); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const getLogStat = async () => { | ||||
|     let localStartTimestamp = Date.parse(start_timestamp) / 1000; | ||||
|     let localEndTimestamp = Date.parse(end_timestamp) / 1000; | ||||
|     let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       setStat(data); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const loadLogs = async (startIdx) => { | ||||
|     let url = `/api/log/self/?p=${startIdx}&type=${logType}`; | ||||
|     if (mode === 'all') { | ||||
|       url = `/api/log/?p=${startIdx}&type=${logType}`; | ||||
|     let url = ''; | ||||
|     let localStartTimestamp = Date.parse(start_timestamp) / 1000; | ||||
|     let localEndTimestamp = Date.parse(end_timestamp) / 1000; | ||||
|     if (isAdminUser) { | ||||
|       url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; | ||||
|     } else { | ||||
|       url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; | ||||
|     } | ||||
|     const res = await API.get(url); | ||||
|     const { success, message, data } = res.data; | ||||
| @@ -61,8 +107,8 @@ const LogsTable = () => { | ||||
|       if (startIdx === 0) { | ||||
|         setLogs(data); | ||||
|       } else { | ||||
|         let newLogs = logs; | ||||
|         newLogs.push(...data); | ||||
|         let newLogs = [...logs]; | ||||
|         newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); | ||||
|         setLogs(newLogs); | ||||
|       } | ||||
|     } else { | ||||
| @@ -83,20 +129,18 @@ const LogsTable = () => { | ||||
|  | ||||
|   const refresh = async () => { | ||||
|     setLoading(true); | ||||
|     setActivePage(1) | ||||
|     await loadLogs(0); | ||||
|     if (isAdminUser) { | ||||
|       getLogStat().then(); | ||||
|     } else { | ||||
|       getLogSelfStat().then(); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     loadLogs(0) | ||||
|       .then() | ||||
|       .catch((reason) => { | ||||
|         showError(reason); | ||||
|       }); | ||||
|   }, []); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     refresh().then(); | ||||
|   }, [mode, logType]); | ||||
|   }, [logType]); | ||||
|  | ||||
|   const searchLogs = async () => { | ||||
|     if (searchKeyword === '') { | ||||
| @@ -125,9 +169,17 @@ const LogsTable = () => { | ||||
|     if (logs.length === 0) return; | ||||
|     setLoading(true); | ||||
|     let sortedLogs = [...logs]; | ||||
|     sortedLogs.sort((a, b) => { | ||||
|       return ('' + a[key]).localeCompare(b[key]); | ||||
|     }); | ||||
|     if (typeof sortedLogs[0][key] === 'string'){ | ||||
|       sortedLogs.sort((a, b) => { | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       }); | ||||
|     } else { | ||||
|       sortedLogs.sort((a, b) => { | ||||
|         if (a[key] === b[key]) return 0; | ||||
|         if (a[key] > b[key]) return -1; | ||||
|         if (a[key] < b[key]) return 1; | ||||
|       }); | ||||
|     } | ||||
|     if (sortedLogs[0].id === logs[0].id) { | ||||
|       sortedLogs.reverse(); | ||||
|     } | ||||
| @@ -137,118 +189,178 @@ const LogsTable = () => { | ||||
|  | ||||
|   return ( | ||||
|     <> | ||||
|       <Table basic> | ||||
|         <Table.Header> | ||||
|           <Table.Row> | ||||
|             <Table.HeaderCell | ||||
|               style={{ cursor: 'pointer' }} | ||||
|               onClick={() => { | ||||
|                 sortLog('created_time'); | ||||
|               }} | ||||
|               width={3} | ||||
|             > | ||||
|               时间 | ||||
|             </Table.HeaderCell> | ||||
|       <Segment> | ||||
|         <Header as='h3'>使用明细(总消耗额度:{renderQuota(stat.quota)})</Header> | ||||
|         <Form> | ||||
|           <Form.Group> | ||||
|             { | ||||
|               showModePanel && ( | ||||
|                 <Table.HeaderCell | ||||
|               isAdminUser && ( | ||||
|                 <Form.Input fluid label={'用户名称'} width={2} value={username} | ||||
|                             placeholder={'可选值'} name='username' | ||||
|                             onChange={handleInputChange} /> | ||||
|               ) | ||||
|             } | ||||
|             <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name} | ||||
|                         placeholder={'可选值'} name='token_name' onChange={handleInputChange} /> | ||||
|             <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值' | ||||
|                         name='model_name' | ||||
|                         onChange={handleInputChange} /> | ||||
|             <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local' | ||||
|                         name='start_timestamp' | ||||
|                         onChange={handleInputChange} /> | ||||
|             <Form.Input fluid label='结束时间' width={4} value={end_timestamp} type='datetime-local' | ||||
|                         name='end_timestamp' | ||||
|                         onChange={handleInputChange} /> | ||||
|             <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button> | ||||
|           </Form.Group> | ||||
|         </Form> | ||||
|         <Table basic compact size='small'> | ||||
|           <Table.Header> | ||||
|             <Table.Row> | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('created_time'); | ||||
|                 }} | ||||
|                 width={3} | ||||
|               > | ||||
|                 时间 | ||||
|               </Table.HeaderCell> | ||||
|               { | ||||
|                 isAdminUser && <Table.HeaderCell | ||||
|                   style={{ cursor: 'pointer' }} | ||||
|                   onClick={() => { | ||||
|                     sortLog('user_id'); | ||||
|                     sortLog('username'); | ||||
|                   }} | ||||
|                   width={1} | ||||
|                 > | ||||
|                   用户 | ||||
|                 </Table.HeaderCell> | ||||
|               ) | ||||
|             } | ||||
|             <Table.HeaderCell | ||||
|               style={{ cursor: 'pointer' }} | ||||
|               onClick={() => { | ||||
|                 sortLog('type'); | ||||
|               }} | ||||
|               width={2} | ||||
|             > | ||||
|               类型 | ||||
|             </Table.HeaderCell> | ||||
|             <Table.HeaderCell | ||||
|               style={{ cursor: 'pointer' }} | ||||
|               onClick={() => { | ||||
|                 sortLog('content'); | ||||
|               }} | ||||
|               width={showModePanel ? 10 : 11} | ||||
|             > | ||||
|               详情 | ||||
|             </Table.HeaderCell> | ||||
|           </Table.Row> | ||||
|         </Table.Header> | ||||
|  | ||||
|         <Table.Body> | ||||
|           {logs | ||||
|             .slice( | ||||
|               (activePage - 1) * ITEMS_PER_PAGE, | ||||
|               activePage * ITEMS_PER_PAGE | ||||
|             ) | ||||
|             .map((log, idx) => { | ||||
|               if (log.deleted) return <></>; | ||||
|               return ( | ||||
|                 <Table.Row key={log.created_at}> | ||||
|                   <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> | ||||
|                   { | ||||
|                     showModePanel && ( | ||||
|                       <Table.Cell><Label>{log.user_id}</Label></Table.Cell> | ||||
|                     ) | ||||
|                   } | ||||
|                   <Table.Cell>{renderType(log.type)}</Table.Cell> | ||||
|                   <Table.Cell>{log.content}</Table.Cell> | ||||
|                 </Table.Row> | ||||
|               ); | ||||
|             })} | ||||
|         </Table.Body> | ||||
|  | ||||
|         <Table.Footer> | ||||
|           <Table.Row> | ||||
|             <Table.HeaderCell colSpan={showModePanel ? '5' : '4'}> | ||||
|               { | ||||
|                 showModePanel && ( | ||||
|                   <Select | ||||
|                     placeholder='选择模式' | ||||
|                     options={MODE_OPTIONS} | ||||
|                     style={{ marginRight: '8px' }} | ||||
|                     name='mode' | ||||
|                     value={mode} | ||||
|                     onChange={(e, { name, value }) => { | ||||
|                       setMode(value); | ||||
|                     }} | ||||
|                   /> | ||||
|                 ) | ||||
|               } | ||||
|               <Select | ||||
|                 placeholder='选择明细分类' | ||||
|                 options={LOG_OPTIONS} | ||||
|                 style={{ marginRight: '8px' }} | ||||
|                 name='logType' | ||||
|                 value={logType} | ||||
|                 onChange={(e, { name, value }) => { | ||||
|                   setLogType(value); | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('token_name'); | ||||
|                 }} | ||||
|               /> | ||||
|               <Button size='small' onClick={refresh} loading={loading}>刷新</Button> | ||||
|               <Pagination | ||||
|                 floated='right' | ||||
|                 activePage={activePage} | ||||
|                 onPageChange={onPaginationChange} | ||||
|                 size='small' | ||||
|                 siblingRange={1} | ||||
|                 totalPages={ | ||||
|                   Math.ceil(logs.length / ITEMS_PER_PAGE) + | ||||
|                   (logs.length % ITEMS_PER_PAGE === 0 ? 1 : 0) | ||||
|                 } | ||||
|               /> | ||||
|             </Table.HeaderCell> | ||||
|           </Table.Row> | ||||
|         </Table.Footer> | ||||
|       </Table> | ||||
|                 width={1} | ||||
|               > | ||||
|                 令牌 | ||||
|               </Table.HeaderCell> | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('type'); | ||||
|                 }} | ||||
|                 width={1} | ||||
|               > | ||||
|                 类型 | ||||
|               </Table.HeaderCell> | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('model_name'); | ||||
|                 }} | ||||
|                 width={2} | ||||
|               > | ||||
|                 模型 | ||||
|               </Table.HeaderCell> | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('prompt_tokens'); | ||||
|                 }} | ||||
|                 width={1} | ||||
|               > | ||||
|                 提示 | ||||
|               </Table.HeaderCell> | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('completion_tokens'); | ||||
|                 }} | ||||
|                 width={1} | ||||
|               > | ||||
|                 补全 | ||||
|               </Table.HeaderCell> | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('quota'); | ||||
|                 }} | ||||
|                 width={2} | ||||
|               > | ||||
|                 消耗额度 | ||||
|               </Table.HeaderCell> | ||||
|               <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortLog('content'); | ||||
|                 }} | ||||
|                 width={isAdminUser ? 4 : 5} | ||||
|               > | ||||
|                 详情 | ||||
|               </Table.HeaderCell> | ||||
|             </Table.Row> | ||||
|           </Table.Header> | ||||
|  | ||||
|           <Table.Body> | ||||
|             {logs | ||||
|               .slice( | ||||
|                 (activePage - 1) * ITEMS_PER_PAGE, | ||||
|                 activePage * ITEMS_PER_PAGE | ||||
|               ) | ||||
|               .map((log, idx) => { | ||||
|                 if (log.deleted) return <></>; | ||||
|                 return ( | ||||
|                   <Table.Row key={log.created_at}> | ||||
|                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> | ||||
|                     { | ||||
|                       isAdminUser && ( | ||||
|                         <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell> | ||||
|                       ) | ||||
|                     } | ||||
|                     <Table.Cell>{log.token_name ? <Label basic>{log.token_name}</Label> : ''}</Table.Cell> | ||||
|                     <Table.Cell>{renderType(log.type)}</Table.Cell> | ||||
|                     <Table.Cell>{log.model_name ? <Label basic>{log.model_name}</Label> : ''}</Table.Cell> | ||||
|                     <Table.Cell>{log.prompt_tokens ? log.prompt_tokens : ''}</Table.Cell> | ||||
|                     <Table.Cell>{log.completion_tokens ? log.completion_tokens : ''}</Table.Cell> | ||||
|                     <Table.Cell>{log.quota ? renderQuota(log.quota, 6) : ''}</Table.Cell> | ||||
|                     <Table.Cell>{log.content}</Table.Cell> | ||||
|                   </Table.Row> | ||||
|                 ); | ||||
|               })} | ||||
|           </Table.Body> | ||||
|  | ||||
|           <Table.Footer> | ||||
|             <Table.Row> | ||||
|               <Table.HeaderCell colSpan={'9'}> | ||||
|                 <Select | ||||
|                   placeholder='选择明细分类' | ||||
|                   options={LOG_OPTIONS} | ||||
|                   style={{ marginRight: '8px' }} | ||||
|                   name='logType' | ||||
|                   value={logType} | ||||
|                   onChange={(e, { name, value }) => { | ||||
|                     setLogType(value); | ||||
|                   }} | ||||
|                 /> | ||||
|                 <Button size='small' onClick={refresh} loading={loading}>刷新</Button> | ||||
|                 <Pagination | ||||
|                   floated='right' | ||||
|                   activePage={activePage} | ||||
|                   onPageChange={onPaginationChange} | ||||
|                   size='small' | ||||
|                   siblingRange={1} | ||||
|                   totalPages={ | ||||
|                     Math.ceil(logs.length / ITEMS_PER_PAGE) + | ||||
|                     (logs.length % ITEMS_PER_PAGE === 0 ? 1 : 0) | ||||
|                   } | ||||
|                 /> | ||||
|               </Table.HeaderCell> | ||||
|             </Table.Row> | ||||
|           </Table.Footer> | ||||
|         </Table> | ||||
|       </Segment> | ||||
|     </> | ||||
|   ); | ||||
| }; | ||||
|   | ||||
| @@ -18,7 +18,9 @@ const OperationSetting = () => { | ||||
|     ChannelDisableThreshold: 0, | ||||
|     LogConsumeEnabled: '', | ||||
|     DisplayInCurrencyEnabled: '', | ||||
|     DisplayTokenStatEnabled: '' | ||||
|     DisplayTokenStatEnabled: '', | ||||
|     ApproximateTokenEnabled: '', | ||||
|     RetryTimes: 0, | ||||
|   }); | ||||
|   const [originInputs, setOriginInputs] = useState({}); | ||||
|   let [loading, setLoading] = useState(false); | ||||
| @@ -74,9 +76,6 @@ const OperationSetting = () => { | ||||
|   const submitConfig = async (group) => { | ||||
|     switch (group) { | ||||
|       case 'monitor': | ||||
|         if (originInputs['AutomaticDisableChannelEnabled'] !== inputs.AutomaticDisableChannelEnabled) { | ||||
|           await updateOption('AutomaticDisableChannelEnabled', inputs.AutomaticDisableChannelEnabled); | ||||
|         } | ||||
|         if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { | ||||
|           await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); | ||||
|         } | ||||
| @@ -124,6 +123,9 @@ const OperationSetting = () => { | ||||
|         if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { | ||||
|           await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); | ||||
|         } | ||||
|         if (originInputs['RetryTimes'] !== inputs.RetryTimes) { | ||||
|           await updateOption('RetryTimes', inputs.RetryTimes); | ||||
|         } | ||||
|         break; | ||||
|     } | ||||
|   }; | ||||
| @@ -135,7 +137,7 @@ const OperationSetting = () => { | ||||
|           <Header as='h3'> | ||||
|             通用设置 | ||||
|           </Header> | ||||
|           <Form.Group widths={3}> | ||||
|           <Form.Group widths={4}> | ||||
|             <Form.Input | ||||
|               label='充值链接' | ||||
|               name='TopUpLink' | ||||
| @@ -164,6 +166,17 @@ const OperationSetting = () => { | ||||
|               step='0.01' | ||||
|               placeholder='一单位货币能兑换的额度' | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='失败重试次数' | ||||
|               name='RetryTimes' | ||||
|               type={'number'} | ||||
|               step='1' | ||||
|               min='0' | ||||
|               onChange={handleInputChange} | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.RetryTimes} | ||||
|               placeholder='失败重试次数' | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Group inline> | ||||
|             <Form.Checkbox | ||||
| @@ -184,6 +197,12 @@ const OperationSetting = () => { | ||||
|               name='DisplayTokenStatEnabled' | ||||
|               onChange={handleInputChange} | ||||
|             /> | ||||
|             <Form.Checkbox | ||||
|               checked={inputs.ApproximateTokenEnabled === 'true'} | ||||
|               label='使用近似的方式估算 token 数以减少计算量' | ||||
|               name='ApproximateTokenEnabled' | ||||
|               onChange={handleInputChange} | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={() => { | ||||
|             submitConfig('general').then(); | ||||
|   | ||||
| @@ -12,6 +12,11 @@ const PasswordResetConfirm = () => { | ||||
|  | ||||
|   const [loading, setLoading] = useState(false); | ||||
|  | ||||
|   const [disableButton, setDisableButton] = useState(false); | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|  | ||||
|   const [newPassword, setNewPassword] = useState(''); | ||||
|  | ||||
|   const [searchParams, setSearchParams] = useSearchParams(); | ||||
|   useEffect(() => { | ||||
|     let token = searchParams.get('token'); | ||||
| @@ -22,7 +27,21 @@ const PasswordResetConfirm = () => { | ||||
|     }); | ||||
|   }, []); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let countdownInterval = null; | ||||
|     if (disableButton && countdown > 0) { | ||||
|       countdownInterval = setInterval(() => { | ||||
|         setCountdown(countdown - 1); | ||||
|       }, 1000); | ||||
|     } else if (countdown === 0) { | ||||
|       setDisableButton(false); | ||||
|       setCountdown(30); | ||||
|     } | ||||
|     return () => clearInterval(countdownInterval);  | ||||
|   }, [disableButton, countdown]); | ||||
|  | ||||
|   async function handleSubmit(e) { | ||||
|     setDisableButton(true); | ||||
|     if (!email) return; | ||||
|     setLoading(true); | ||||
|     const res = await API.post(`/api/user/reset`, { | ||||
| @@ -32,14 +51,15 @@ const PasswordResetConfirm = () => { | ||||
|     const { success, message } = res.data; | ||||
|     if (success) { | ||||
|       let password = res.data.data; | ||||
|       setNewPassword(password); | ||||
|       await copy(password); | ||||
|       showNotice(`密码已重置并已复制到剪贴板:${password}`); | ||||
|       showNotice(`新密码已复制到剪贴板:${password}`); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|     setLoading(false); | ||||
|   } | ||||
|  | ||||
|    | ||||
|   return ( | ||||
|     <Grid textAlign='center' style={{ marginTop: '48px' }}> | ||||
|       <Grid.Column style={{ maxWidth: 450 }}> | ||||
| @@ -57,20 +77,37 @@ const PasswordResetConfirm = () => { | ||||
|               value={email} | ||||
|               readOnly | ||||
|             /> | ||||
|             {newPassword && ( | ||||
|               <Form.Input | ||||
|               fluid | ||||
|               icon='lock' | ||||
|               iconPosition='left' | ||||
|               placeholder='新密码' | ||||
|               name='newPassword' | ||||
|               value={newPassword} | ||||
|               readOnly | ||||
|               onClick={(e) => { | ||||
|                 e.target.select(); | ||||
|                 navigator.clipboard.writeText(newPassword); | ||||
|                 showNotice(`密码已复制到剪贴板:${newPassword}`); | ||||
|               }} | ||||
|             />             | ||||
|             )} | ||||
|             <Button | ||||
|               color='' | ||||
|               color='green' | ||||
|               fluid | ||||
|               size='large' | ||||
|               onClick={handleSubmit} | ||||
|               loading={loading} | ||||
|               disabled={disableButton} | ||||
|             > | ||||
|               提交 | ||||
|               {disableButton ? `密码重置完成` : '提交'} | ||||
|             </Button> | ||||
|           </Segment> | ||||
|         </Form> | ||||
|       </Grid.Column> | ||||
|     </Grid> | ||||
|   ); | ||||
|   );   | ||||
| }; | ||||
|  | ||||
| export default PasswordResetConfirm; | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import Turnstile from 'react-turnstile'; | ||||
|  | ||||
| const PasswordResetForm = () => { | ||||
|   const [inputs, setInputs] = useState({ | ||||
|     email: '', | ||||
|     email: '' | ||||
|   }); | ||||
|   const { email } = inputs; | ||||
|  | ||||
| @@ -13,24 +13,29 @@ const PasswordResetForm = () => { | ||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||
|   const [turnstileToken, setTurnstileToken] = useState(''); | ||||
|   const [disableButton, setDisableButton] = useState(false); | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let status = localStorage.getItem('status'); | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       if (status.turnstile_check) { | ||||
|         setTurnstileEnabled(true); | ||||
|         setTurnstileSiteKey(status.turnstile_site_key); | ||||
|       } | ||||
|     let countdownInterval = null; | ||||
|     if (disableButton && countdown > 0) { | ||||
|       countdownInterval = setInterval(() => { | ||||
|         setCountdown(countdown - 1); | ||||
|       }, 1000); | ||||
|     } else if (countdown === 0) { | ||||
|       setDisableButton(false); | ||||
|       setCountdown(30); | ||||
|     } | ||||
|   }, []); | ||||
|     return () => clearInterval(countdownInterval); | ||||
|   }, [disableButton, countdown]); | ||||
|  | ||||
|   function handleChange(e) { | ||||
|     const { name, value } = e.target; | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|     setInputs(inputs => ({ ...inputs, [name]: value })); | ||||
|   } | ||||
|  | ||||
|   async function handleSubmit(e) { | ||||
|     setDisableButton(true); | ||||
|     if (!email) return; | ||||
|     if (turnstileEnabled && turnstileToken === '') { | ||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||
| @@ -78,13 +83,14 @@ const PasswordResetForm = () => { | ||||
|               <></> | ||||
|             )} | ||||
|             <Button | ||||
|               color='' | ||||
|               color='green' | ||||
|               fluid | ||||
|               size='large' | ||||
|               onClick={handleSubmit} | ||||
|               loading={loading} | ||||
|               disabled={disableButton} | ||||
|             > | ||||
|               提交 | ||||
|               {disableButton ? `重试 (${countdown})` : '提交'} | ||||
|             </Button> | ||||
|           </Segment> | ||||
|         </Form> | ||||
|   | ||||
| @@ -1,22 +1,32 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { Link, useNavigate } from 'react-router-dom'; | ||||
| import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | ||||
| import Turnstile from 'react-turnstile'; | ||||
| import { UserContext } from '../context/User'; | ||||
|  | ||||
| const PersonalSetting = () => { | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const [inputs, setInputs] = useState({ | ||||
|     wechat_verification_code: '', | ||||
|     email_verification_code: '', | ||||
|     email: '', | ||||
|     self_account_deletion_confirmation: '' | ||||
|   }); | ||||
|   const [status, setStatus] = useState({}); | ||||
|   const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); | ||||
|   const [showEmailBindModal, setShowEmailBindModal] = useState(false); | ||||
|   const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); | ||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||
|   const [turnstileToken, setTurnstileToken] = useState(''); | ||||
|   const [loading, setLoading] = useState(false); | ||||
|   const [disableButton, setDisableButton] = useState(false); | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|   const [affLink, setAffLink] = useState(""); | ||||
|   const [systemToken, setSystemToken] = useState(""); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let status = localStorage.getItem('status'); | ||||
| @@ -30,6 +40,19 @@ const PersonalSetting = () => { | ||||
|     } | ||||
|   }, []); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let countdownInterval = null; | ||||
|     if (disableButton && countdown > 0) { | ||||
|       countdownInterval = setInterval(() => { | ||||
|         setCountdown(countdown - 1); | ||||
|       }, 1000); | ||||
|     } else if (countdown === 0) { | ||||
|       setDisableButton(false); | ||||
|       setCountdown(30); | ||||
|     } | ||||
|     return () => clearInterval(countdownInterval); // Clean up on unmount | ||||
|   }, [disableButton, countdown]); | ||||
|  | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|   }; | ||||
| @@ -38,8 +61,10 @@ const PersonalSetting = () => { | ||||
|     const res = await API.get('/api/user/token'); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       setSystemToken(data); | ||||
|       setAffLink("");  | ||||
|       await copy(data); | ||||
|       showSuccess(`令牌已重置并已复制到剪贴板:${data}`); | ||||
|       showSuccess(`令牌已重置并已复制到剪贴板`); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -50,8 +75,42 @@ const PersonalSetting = () => { | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       let link = `${window.location.origin}/register?aff=${data}`; | ||||
|       setAffLink(link); | ||||
|       setSystemToken(""); | ||||
|       await copy(link); | ||||
|       showNotice(`邀请链接已复制到剪切板:${link}`); | ||||
|       showSuccess(`邀请链接已复制到剪切板`); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const handleAffLinkClick = async (e) => { | ||||
|     e.target.select(); | ||||
|     await copy(e.target.value); | ||||
|     showSuccess(`邀请链接已复制到剪切板`); | ||||
|   }; | ||||
|  | ||||
|   const handleSystemTokenClick = async (e) => { | ||||
|     e.target.select(); | ||||
|     await copy(e.target.value); | ||||
|     showSuccess(`系统令牌已复制到剪切板`); | ||||
|   }; | ||||
|  | ||||
|   const deleteAccount = async () => { | ||||
|     if (inputs.self_account_deletion_confirmation !== userState.user.username) { | ||||
|       showError('请输入你的账户名以确认删除!'); | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     const res = await API.delete('/api/user/self'); | ||||
|     const { success, message } = res.data; | ||||
|  | ||||
|     if (success) { | ||||
|       showSuccess('账户已删除!'); | ||||
|       await API.get('/api/user/logout'); | ||||
|       userDispatch({ type: 'logout' }); | ||||
|       localStorage.removeItem('user'); | ||||
|       navigate('/login'); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -78,6 +137,7 @@ const PersonalSetting = () => { | ||||
|   }; | ||||
|  | ||||
|   const sendVerificationCode = async () => { | ||||
|     setDisableButton(true); | ||||
|     if (inputs.email === '') return; | ||||
|     if (turnstileEnabled && turnstileToken === '') { | ||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||
| @@ -123,6 +183,28 @@ const PersonalSetting = () => { | ||||
|       </Button> | ||||
|       <Button onClick={generateAccessToken}>生成系统访问令牌</Button> | ||||
|       <Button onClick={getAffLink}>复制邀请链接</Button> | ||||
|       <Button onClick={() => { | ||||
|         setShowAccountDeleteModal(true); | ||||
|       }}>删除个人账户</Button> | ||||
|        | ||||
|       {systemToken && ( | ||||
|         <Form.Input  | ||||
|           fluid  | ||||
|           readOnly  | ||||
|           value={systemToken}  | ||||
|           onClick={handleSystemTokenClick} | ||||
|           style={{ marginTop: '10px' }} | ||||
|         /> | ||||
|       )} | ||||
|       {affLink && ( | ||||
|         <Form.Input  | ||||
|           fluid  | ||||
|           readOnly  | ||||
|           value={affLink}  | ||||
|           onClick={handleAffLinkClick} | ||||
|           style={{ marginTop: '10px' }} | ||||
|         /> | ||||
|       )} | ||||
|       <Divider /> | ||||
|       <Header as='h3'>账号绑定</Header> | ||||
|       { | ||||
| @@ -195,8 +277,8 @@ const PersonalSetting = () => { | ||||
|                 name='email' | ||||
|                 type='email' | ||||
|                 action={ | ||||
|                   <Button onClick={sendVerificationCode} disabled={loading}> | ||||
|                     获取验证码 | ||||
|                   <Button onClick={sendVerificationCode} disabled={disableButton || loading}> | ||||
|                     {disableButton ? `重新发送(${countdown})` : '获取验证码'} | ||||
|                   </Button> | ||||
|                 } | ||||
|               /> | ||||
| @@ -217,6 +299,7 @@ const PersonalSetting = () => { | ||||
|               ) : ( | ||||
|                 <></> | ||||
|               )} | ||||
|               <div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}> | ||||
|               <Button | ||||
|                 color='' | ||||
|                 fluid | ||||
| @@ -224,8 +307,69 @@ const PersonalSetting = () => { | ||||
|                 onClick={bindEmail} | ||||
|                 loading={loading} | ||||
|               > | ||||
|                 绑定 | ||||
|                 确认绑定 | ||||
|               </Button> | ||||
|               <div style={{ width: '1rem' }}></div>  | ||||
|               <Button | ||||
|                 fluid | ||||
|                 size='large' | ||||
|                 onClick={() => setShowEmailBindModal(false)} | ||||
|               > | ||||
|                 取消 | ||||
|               </Button> | ||||
|               </div> | ||||
|             </Form> | ||||
|           </Modal.Description> | ||||
|         </Modal.Content> | ||||
|       </Modal> | ||||
|       <Modal | ||||
|         onClose={() => setShowAccountDeleteModal(false)} | ||||
|         onOpen={() => setShowAccountDeleteModal(true)} | ||||
|         open={showAccountDeleteModal} | ||||
|         size={'tiny'} | ||||
|         style={{ maxWidth: '450px' }} | ||||
|       > | ||||
|         <Modal.Header>危险操作</Modal.Header> | ||||
|         <Modal.Content> | ||||
|         <Message>您正在删除自己的帐户,将清空所有数据且不可恢复</Message> | ||||
|           <Modal.Description> | ||||
|             <Form size='large'> | ||||
|               <Form.Input | ||||
|                 fluid | ||||
|                 placeholder={`输入你的账户名 ${userState?.user?.username} 以确认删除`} | ||||
|                 name='self_account_deletion_confirmation' | ||||
|                 value={inputs.self_account_deletion_confirmation} | ||||
|                 onChange={handleInputChange} | ||||
|               /> | ||||
|               {turnstileEnabled ? ( | ||||
|                 <Turnstile | ||||
|                   sitekey={turnstileSiteKey} | ||||
|                   onVerify={(token) => { | ||||
|                     setTurnstileToken(token); | ||||
|                   }} | ||||
|                 /> | ||||
|               ) : ( | ||||
|                 <></> | ||||
|               )} | ||||
|               <div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}> | ||||
|                 <Button | ||||
|                   color='red' | ||||
|                   fluid | ||||
|                   size='large' | ||||
|                   onClick={deleteAccount} | ||||
|                   loading={loading} | ||||
|                 > | ||||
|                   确认删除 | ||||
|                 </Button> | ||||
|                 <div style={{ width: '1rem' }}></div> | ||||
|                 <Button | ||||
|                   fluid | ||||
|                   size='large' | ||||
|                   onClick={() => setShowAccountDeleteModal(false)} | ||||
|                 > | ||||
|                   取消 | ||||
|                 </Button> | ||||
|               </div> | ||||
|             </Form> | ||||
|           </Modal.Description> | ||||
|         </Modal.Content> | ||||
|   | ||||
| @@ -1,13 +1,5 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { | ||||
|   Button, | ||||
|   Form, | ||||
|   Grid, | ||||
|   Header, | ||||
|   Image, | ||||
|   Message, | ||||
|   Segment, | ||||
| } from 'semantic-ui-react'; | ||||
| import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; | ||||
| import { Link, useNavigate } from 'react-router-dom'; | ||||
| import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; | ||||
| import Turnstile from 'react-turnstile'; | ||||
| @@ -18,7 +10,7 @@ const RegisterForm = () => { | ||||
|     password: '', | ||||
|     password2: '', | ||||
|     email: '', | ||||
|     verification_code: '', | ||||
|     verification_code: '' | ||||
|   }); | ||||
|   const { username, password, password2 } = inputs; | ||||
|   const [showEmailVerification, setShowEmailVerification] = useState(false); | ||||
| @@ -178,7 +170,7 @@ const RegisterForm = () => { | ||||
|               <></> | ||||
|             )} | ||||
|             <Button | ||||
|               color='' | ||||
|               color='green' | ||||
|               fluid | ||||
|               size='large' | ||||
|               onClick={handleSubmit} | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react'; | ||||
| import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers'; | ||||
| import { Button, Divider, Form, Grid, Header, Input, Message } from 'semantic-ui-react'; | ||||
| import { API, removeTrailingSlash, showError } from '../helpers'; | ||||
|  | ||||
| const SystemSetting = () => { | ||||
|   let [inputs, setInputs] = useState({ | ||||
| @@ -26,9 +26,13 @@ const SystemSetting = () => { | ||||
|     TurnstileSiteKey: '', | ||||
|     TurnstileSecretKey: '', | ||||
|     RegisterEnabled: '', | ||||
|     EmailDomainRestrictionEnabled: '', | ||||
|     EmailDomainWhitelist: '' | ||||
|   }); | ||||
|   const [originInputs, setOriginInputs] = useState({}); | ||||
|   let [loading, setLoading] = useState(false); | ||||
|   const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); | ||||
|   const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); | ||||
|  | ||||
|   const getOptions = async () => { | ||||
|     const res = await API.get('/api/option/'); | ||||
| @@ -38,8 +42,15 @@ const SystemSetting = () => { | ||||
|       data.forEach((item) => { | ||||
|         newInputs[item.key] = item.value; | ||||
|       }); | ||||
|       setInputs(newInputs); | ||||
|       setInputs({ | ||||
|         ...newInputs, | ||||
|         EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',') | ||||
|       }); | ||||
|       setOriginInputs(newInputs); | ||||
|  | ||||
|       setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => { | ||||
|         return { key: item, text: item, value: item }; | ||||
|       })); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -58,6 +69,7 @@ const SystemSetting = () => { | ||||
|       case 'GitHubOAuthEnabled': | ||||
|       case 'WeChatAuthEnabled': | ||||
|       case 'TurnstileCheckEnabled': | ||||
|       case 'EmailDomainRestrictionEnabled': | ||||
|       case 'RegisterEnabled': | ||||
|         value = inputs[key] === 'true' ? 'false' : 'true'; | ||||
|         break; | ||||
| @@ -70,7 +82,12 @@ const SystemSetting = () => { | ||||
|     }); | ||||
|     const { success, message } = res.data; | ||||
|     if (success) { | ||||
|       setInputs((inputs) => ({ ...inputs, [key]: value })); | ||||
|       if (key === 'EmailDomainWhitelist') { | ||||
|         value = value.split(','); | ||||
|       } | ||||
|       setInputs((inputs) => ({ | ||||
|         ...inputs, [key]: value | ||||
|       })); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -88,7 +105,8 @@ const SystemSetting = () => { | ||||
|       name === 'WeChatServerToken' || | ||||
|       name === 'WeChatAccountQRCodeImageURL' || | ||||
|       name === 'TurnstileSiteKey' || | ||||
|       name === 'TurnstileSecretKey' | ||||
|       name === 'TurnstileSecretKey' || | ||||
|       name === 'EmailDomainWhitelist' | ||||
|     ) { | ||||
|       setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|     } else { | ||||
| @@ -125,6 +143,16 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|  | ||||
|   const submitEmailDomainWhitelist = async () => { | ||||
|     if ( | ||||
|       originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') && | ||||
|       inputs.SMTPToken !== '' | ||||
|     ) { | ||||
|       await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(',')); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitWeChat = async () => { | ||||
|     if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { | ||||
|       await updateOption( | ||||
| @@ -173,6 +201,22 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitNewRestrictedDomain = () => { | ||||
|     const localDomainList = inputs.EmailDomainWhitelist; | ||||
|     if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) { | ||||
|       setRestrictedDomainInput(''); | ||||
|       setInputs({ | ||||
|         ...inputs, | ||||
|         EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], | ||||
|       }); | ||||
|       setEmailDomainWhitelist([...EmailDomainWhitelist, { | ||||
|         key: restrictedDomainInput, | ||||
|         text: restrictedDomainInput, | ||||
|         value: restrictedDomainInput, | ||||
|       }]); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return ( | ||||
|     <Grid columns={1}> | ||||
|       <Grid.Column> | ||||
| @@ -239,6 +283,54 @@ const SystemSetting = () => { | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置邮箱域名白名单 | ||||
|             <Header.Subheader>用以防止恶意用户利用临时邮箱批量注册</Header.Subheader> | ||||
|           </Header> | ||||
|           <Form.Group widths={3}> | ||||
|             <Form.Checkbox | ||||
|               label='启用邮箱域名白名单' | ||||
|               name='EmailDomainRestrictionEnabled' | ||||
|               onChange={handleInputChange} | ||||
|               checked={inputs.EmailDomainRestrictionEnabled === 'true'} | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Group widths={2}> | ||||
|             <Form.Dropdown | ||||
|               label='允许的邮箱域名' | ||||
|               placeholder='允许的邮箱域名' | ||||
|               name='EmailDomainWhitelist' | ||||
|               required | ||||
|               fluid | ||||
|               multiple | ||||
|               selection | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.EmailDomainWhitelist} | ||||
|               autoComplete='new-password' | ||||
|               options={EmailDomainWhitelist} | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='添加新的允许的邮箱域名' | ||||
|               action={ | ||||
|                 <Button type='button' onClick={() => { | ||||
|                   submitNewRestrictedDomain(); | ||||
|                 }}>填入</Button> | ||||
|               } | ||||
|               onKeyDown={(e) => { | ||||
|                 if (e.key === 'Enter') { | ||||
|                   submitNewRestrictedDomain(); | ||||
|                 } | ||||
|               }} | ||||
|               autoComplete='new-password' | ||||
|               placeholder='输入新的允许的邮箱域名' | ||||
|               value={restrictedDomainInput} | ||||
|               onChange={(e, { value }) => { | ||||
|                 setRestrictedDomainInput(value); | ||||
|               }} | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={submitEmailDomainWhitelist}>保存邮箱域名白名单设置</Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 SMTP | ||||
|             <Header.Subheader>用以支持系统的邮件发送</Header.Subheader> | ||||
| @@ -284,7 +376,7 @@ const SystemSetting = () => { | ||||
|               onChange={handleInputChange} | ||||
|               type='password' | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.SMTPToken} | ||||
|               checked={inputs.RegisterEnabled === 'true'} | ||||
|               placeholder='敏感信息不会发送到前端显示' | ||||
|             /> | ||||
|           </Form.Group> | ||||
|   | ||||
| @@ -1,11 +1,22 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Label, Modal, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Button, Dropdown, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers'; | ||||
|  | ||||
| import { ITEMS_PER_PAGE } from '../constants'; | ||||
| import { renderQuota } from '../helpers/render'; | ||||
|  | ||||
| const COPY_OPTIONS = [ | ||||
|   { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, | ||||
|   { key: 'ama', text: 'AMA 问天', value: 'ama' }, | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||
| ]; | ||||
|  | ||||
| const OPEN_LINK_OPTIONS = [ | ||||
|   { key: 'ama', text: 'AMA 问天', value: 'ama' }, | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||
| ]; | ||||
|  | ||||
| function renderTimestamp(timestamp) { | ||||
|   return ( | ||||
|     <> | ||||
| @@ -45,8 +56,8 @@ const TokensTable = () => { | ||||
|       if (startIdx === 0) { | ||||
|         setTokens(data); | ||||
|       } else { | ||||
|         let newTokens = tokens; | ||||
|         newTokens.push(...data); | ||||
|         let newTokens = [...tokens]; | ||||
|         newTokens.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); | ||||
|         setTokens(newTokens); | ||||
|       } | ||||
|     } else { | ||||
| @@ -67,7 +78,85 @@ const TokensTable = () => { | ||||
|  | ||||
|   const refresh = async () => { | ||||
|     setLoading(true); | ||||
|     await loadTokens(0); | ||||
|     await loadTokens(activePage - 1); | ||||
|   }; | ||||
|  | ||||
|   const onCopy = async (type, key) => { | ||||
|     let status = localStorage.getItem('status'); | ||||
|     let serverAddress = ''; | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       serverAddress = status.server_address; | ||||
|     } | ||||
|     if (serverAddress === '') { | ||||
|       serverAddress = window.location.origin; | ||||
|     } | ||||
|     let encodedServerAddress = encodeURIComponent(serverAddress); | ||||
|     const nextLink = localStorage.getItem('chat_link'); | ||||
|     let nextUrl; | ||||
|    | ||||
|     if (nextLink) { | ||||
|       nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`; | ||||
|     } else { | ||||
|       nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||
|     } | ||||
|  | ||||
|     let url; | ||||
|     switch (type) { | ||||
|       case 'ama': | ||||
|         url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; | ||||
|         break; | ||||
|       case 'opencat': | ||||
|         url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; | ||||
|         break; | ||||
|       case 'next': | ||||
|         url = nextUrl; | ||||
|         break; | ||||
|       default: | ||||
|         url = `sk-${key}`; | ||||
|     } | ||||
|     if (await copy(url)) { | ||||
|       showSuccess('已复制到剪贴板!'); | ||||
|     } else { | ||||
|       showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); | ||||
|       setSearchKeyword(url); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const onOpenLink = async (type, key) => { | ||||
|     let status = localStorage.getItem('status'); | ||||
|     let serverAddress = ''; | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       serverAddress = status.server_address;  | ||||
|     } | ||||
|     if (serverAddress === '') { | ||||
|       serverAddress = window.location.origin; | ||||
|     } | ||||
|     let encodedServerAddress = encodeURIComponent(serverAddress); | ||||
|     const chatLink = localStorage.getItem('chat_link'); | ||||
|     let defaultUrl; | ||||
|    | ||||
|     if (chatLink) { | ||||
|       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; | ||||
|     } else { | ||||
|       defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||
|     } | ||||
|     let url; | ||||
|     switch (type) { | ||||
|       case 'ama': | ||||
|         url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; | ||||
|         break; | ||||
|    | ||||
|       case 'opencat': | ||||
|         url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; | ||||
|         break; | ||||
|          | ||||
|       default: | ||||
|         url = defaultUrl; | ||||
|     } | ||||
|    | ||||
|     window.open(url, '_blank'); | ||||
|   } | ||||
|  | ||||
|   useEffect(() => { | ||||
| @@ -235,21 +324,51 @@ const TokensTable = () => { | ||||
|                   <Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <div> | ||||
|                       <Button | ||||
|                         size={'small'} | ||||
|                         positive | ||||
|                         onClick={async () => { | ||||
|                           let key = "sk-" + token.key; | ||||
|                           if (await copy(key)) { | ||||
|                             showSuccess('已复制到剪贴板!'); | ||||
|                           } else { | ||||
|                             showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); | ||||
|                             setSearchKeyword(key); | ||||
|                           } | ||||
|                         }} | ||||
|                       > | ||||
|                         复制 | ||||
|                       </Button> | ||||
|                     <Button.Group color='green' size={'small'}> | ||||
|                         <Button | ||||
|                           size={'small'} | ||||
|                           positive | ||||
|                           onClick={async () => { | ||||
|                             await onCopy('', token.key); | ||||
|                           }} | ||||
|                         > | ||||
|                           复制 | ||||
|                         </Button> | ||||
|                         <Dropdown | ||||
|                           className='button icon' | ||||
|                           floating | ||||
|                           options={COPY_OPTIONS.map(option => ({ | ||||
|                             ...option, | ||||
|                             onClick: async () => { | ||||
|                               await onCopy(option.value, token.key); | ||||
|                             } | ||||
|                           }))} | ||||
|                           trigger={<></>} | ||||
|                         /> | ||||
|                       </Button.Group> | ||||
|                       {' '} | ||||
|                       <Button.Group color='blue' size={'small'}> | ||||
|                         <Button | ||||
|                             size={'small'} | ||||
|                             positive | ||||
|                             onClick={() => {      | ||||
|                               onOpenLink('', token.key);        | ||||
|                             }}> | ||||
|                             聊天 | ||||
|                           </Button> | ||||
|                           <Dropdown    | ||||
|                             className="button icon"        | ||||
|                             floating | ||||
|                             options={OPEN_LINK_OPTIONS.map(option => ({ | ||||
|                               ...option, | ||||
|                               onClick: async () => { | ||||
|                                 await onOpenLink(option.value, token.key); | ||||
|                               } | ||||
|                             }))}        | ||||
|                             trigger={<></>}    | ||||
|                           /> | ||||
|                       </Button.Group> | ||||
|                       {' '} | ||||
|                       <Popup | ||||
|                         trigger={ | ||||
|                           <Button size='small' negative> | ||||
|   | ||||
| @@ -183,14 +183,6 @@ const UsersTable = () => { | ||||
|             > | ||||
|               分组 | ||||
|             </Table.HeaderCell> | ||||
|             <Table.HeaderCell | ||||
|               style={{ cursor: 'pointer' }} | ||||
|               onClick={() => { | ||||
|                 sortUser('email'); | ||||
|               }} | ||||
|             > | ||||
|               邮箱地址 | ||||
|             </Table.HeaderCell> | ||||
|             <Table.HeaderCell | ||||
|               style={{ cursor: 'pointer' }} | ||||
|               onClick={() => { | ||||
| @@ -233,20 +225,20 @@ const UsersTable = () => { | ||||
|                   <Table.Cell> | ||||
|                     <Popup | ||||
|                       content={user.email ? user.email : '未绑定邮箱地址'} | ||||
|                       key={user.display_name} | ||||
|                       key={user.username} | ||||
|                       header={user.display_name ? user.display_name : user.username} | ||||
|                       trigger={<span>{renderText(user.username, 10)}</span>} | ||||
|                       trigger={<span>{renderText(user.username, 15)}</span>} | ||||
|                       hoverable | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell>{renderGroup(user.group)}</Table.Cell> | ||||
|                   {/*<Table.Cell>*/} | ||||
|                   {/*  {user.email ? <Popup hoverable content={user.email} trigger={<span>{renderText(user.email, 24)}</span>} /> : '无'}*/} | ||||
|                   {/*</Table.Cell>*/} | ||||
|                   <Table.Cell> | ||||
|                     {user.email ? <Popup hoverable content={user.email} trigger={<span>{renderText(user.email, 24)}</span>} /> : '无'} | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <Popup content='剩余额度' trigger={<Label>{renderQuota(user.quota)}</Label>} /> | ||||
|                     <Popup content='已用额度' trigger={<Label>{renderQuota(user.used_quota)}</Label>} /> | ||||
|                     <Popup content='请求次数' trigger={<Label>{renderNumber(user.request_count)}</Label>} /> | ||||
|                     <Popup content='剩余额度' trigger={<Label basic>{renderQuota(user.quota)}</Label>} /> | ||||
|                     <Popup content='已用额度' trigger={<Label basic>{renderQuota(user.used_quota)}</Label>} /> | ||||
|                     <Popup content='请求次数' trigger={<Label basic>{renderNumber(user.request_count)}</Label>} /> | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell>{renderRole(user.role)}</Table.Cell> | ||||
|                   <Table.Cell>{renderStatus(user.status)}</Table.Cell> | ||||
| @@ -320,7 +312,7 @@ const UsersTable = () => { | ||||
|  | ||||
|         <Table.Footer> | ||||
|           <Table.Row> | ||||
|             <Table.HeaderCell colSpan='8'> | ||||
|             <Table.HeaderCell colSpan='7'> | ||||
|               <Button size='small' as={Link} to='/user/add' loading={loading}> | ||||
|                 添加新的用户 | ||||
|               </Button> | ||||
|   | ||||
| @@ -1,13 +1,20 @@ | ||||
| export const CHANNEL_OPTIONS = [ | ||||
|   { key: 1, text: 'OpenAI', value: 1, color: 'green' }, | ||||
|   { key: 8, text: '自定义', value: 8, color: 'pink' }, | ||||
|   { key: 3, text: 'Azure', value: 3, color: 'olive' }, | ||||
|   { key: 2, text: 'API2D', value: 2, color: 'blue' }, | ||||
|   { key: 4, text: 'CloseAI', value: 4, color: 'teal' }, | ||||
|   { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' }, | ||||
|   { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' }, | ||||
|   { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' }, | ||||
|   { key: 9, text: 'AI.LS', value: 9, color: 'yellow' }, | ||||
|   { key: 10, text: 'AI Proxy', value: 10, color: 'purple' }, | ||||
|   { key: 12, text: 'API2GPT', value: 12, color: 'blue' } | ||||
| ]; | ||||
|   { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, | ||||
|   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||
|   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||
|   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||
|   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, | ||||
|   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, | ||||
|   { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, | ||||
|   { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, | ||||
|   { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, | ||||
|   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, | ||||
|   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, | ||||
|   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, | ||||
|   { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } | ||||
| ]; | ||||
| @@ -1,5 +1,5 @@ | ||||
| export const toastConstants = { | ||||
|   SUCCESS_TIMEOUT: 500, | ||||
|   SUCCESS_TIMEOUT: 1500, | ||||
|   INFO_TIMEOUT: 3000, | ||||
|   ERROR_TIMEOUT: 5000, | ||||
|   WARNING_TIMEOUT: 10000, | ||||
|   | ||||
| @@ -46,9 +46,7 @@ const About = () => { | ||||
|             about.startsWith('https://') ? <iframe | ||||
|               src={about} | ||||
|               style={{ width: '100%', height: '100vh', border: 'none' }} | ||||
|             /> : <Segment> | ||||
|               <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> | ||||
|             </Segment> | ||||
|             /> : <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> | ||||
|           } | ||||
|         </> | ||||
|       } | ||||
|   | ||||
| @@ -1,9 +1,15 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | ||||
| import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; | ||||
| import { useParams } from 'react-router-dom'; | ||||
| import { API, showError, showInfo, showSuccess } from '../../helpers'; | ||||
| import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | ||||
| import { CHANNEL_OPTIONS } from '../../constants'; | ||||
|  | ||||
| const MODEL_MAPPING_EXAMPLE = { | ||||
|   'gpt-3.5-turbo-0301': 'gpt-3.5-turbo', | ||||
|   'gpt-4-0314': 'gpt-4', | ||||
|   'gpt-4-32k-0314': 'gpt-4-32k' | ||||
| }; | ||||
|  | ||||
| const EditChannel = () => { | ||||
|   const params = useParams(); | ||||
|   const channelId = params.id; | ||||
| @@ -15,17 +21,44 @@ const EditChannel = () => { | ||||
|     key: '', | ||||
|     base_url: '', | ||||
|     other: '', | ||||
|     model_mapping: '', | ||||
|     models: [], | ||||
|     groups: ['default'] | ||||
|   }; | ||||
|   const [batch, setBatch] = useState(false); | ||||
|   const [inputs, setInputs] = useState(originInputs); | ||||
|   const [originModelOptions, setOriginModelOptions] = useState([]); | ||||
|   const [modelOptions, setModelOptions] = useState([]); | ||||
|   const [groupOptions, setGroupOptions] = useState([]); | ||||
|   const [basicModels, setBasicModels] = useState([]); | ||||
|   const [fullModels, setFullModels] = useState([]); | ||||
|   const [customModel, setCustomModel] = useState(''); | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|     if (name === 'type' && inputs.models.length === 0) { | ||||
|       let localModels = []; | ||||
|       switch (value) { | ||||
|         case 14: | ||||
|           localModels = ['claude-instant-1', 'claude-2']; | ||||
|           break; | ||||
|         case 11: | ||||
|           localModels = ['PaLM-2']; | ||||
|           break; | ||||
|         case 15: | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||
|           break; | ||||
|         case 17: | ||||
|           localModels = ['qwen-v1', 'qwen-plus-v1']; | ||||
|           break; | ||||
|         case 16: | ||||
|           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           break; | ||||
|         case 18: | ||||
|           localModels = ['SparkDesk']; | ||||
|           break; | ||||
|       } | ||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const loadChannel = async () => { | ||||
| @@ -42,6 +75,9 @@ const EditChannel = () => { | ||||
|       } else { | ||||
|         data.groups = data.group.split(','); | ||||
|       } | ||||
|       if (data.model_mapping !== '') { | ||||
|         data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); | ||||
|       } | ||||
|       setInputs(data); | ||||
|     } else { | ||||
|       showError(message); | ||||
| @@ -52,13 +88,16 @@ const EditChannel = () => { | ||||
|   const fetchModels = async () => { | ||||
|     try { | ||||
|       let res = await API.get(`/api/channel/models`); | ||||
|       setModelOptions(res.data.data.map((model) => ({ | ||||
|       let localModelOptions = res.data.data.map((model) => ({ | ||||
|         key: model.id, | ||||
|         text: model.id, | ||||
|         value: model.id | ||||
|       }))); | ||||
|       })); | ||||
|       setOriginModelOptions(localModelOptions); | ||||
|       setFullModels(res.data.data.map((model) => model.id)); | ||||
|       setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id)); | ||||
|       setBasicModels(res.data.data.filter((model) => { | ||||
|         return model.id.startsWith('gpt-3') || model.id.startsWith('text-'); | ||||
|       }).map((model) => model.id)); | ||||
|     } catch (error) { | ||||
|       showError(error.message); | ||||
|     } | ||||
| @@ -77,6 +116,20 @@ const EditChannel = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let localModelOptions = [...originModelOptions]; | ||||
|     inputs.models.forEach((model) => { | ||||
|       if (!localModelOptions.find((option) => option.key === model)) { | ||||
|         localModelOptions.push({ | ||||
|           key: model, | ||||
|           text: model, | ||||
|           value: model | ||||
|         }); | ||||
|       } | ||||
|     }); | ||||
|     setModelOptions(localModelOptions); | ||||
|   }, [originModelOptions, inputs.models]); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     if (isEdit) { | ||||
|       loadChannel().then(); | ||||
| @@ -94,12 +147,19 @@ const EditChannel = () => { | ||||
|       showInfo('请至少选择一个模型!'); | ||||
|       return; | ||||
|     } | ||||
|     if (inputs.model_mapping !== '' && !verifyJSON(inputs.model_mapping)) { | ||||
|       showInfo('模型映射必须是合法的 JSON 格式!'); | ||||
|       return; | ||||
|     } | ||||
|     let localInputs = inputs; | ||||
|     if (localInputs.base_url.endsWith('/')) { | ||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||
|     } | ||||
|     if (localInputs.type === 3 && localInputs.other === '') { | ||||
|       localInputs.other = '2023-03-15-preview'; | ||||
|       localInputs.other = '2023-06-01-preview'; | ||||
|     } | ||||
|     if (localInputs.model_mapping === '') { | ||||
|       localInputs.model_mapping = '{}'; | ||||
|     } | ||||
|     let res; | ||||
|     localInputs.models = localInputs.models.join(','); | ||||
| @@ -131,6 +191,7 @@ const EditChannel = () => { | ||||
|             <Form.Select | ||||
|               label='类型' | ||||
|               name='type' | ||||
|               required | ||||
|               options={CHANNEL_OPTIONS} | ||||
|               value={inputs.type} | ||||
|               onChange={handleInputChange} | ||||
| @@ -158,7 +219,7 @@ const EditChannel = () => { | ||||
|                   <Form.Input | ||||
|                     label='默认 API 版本' | ||||
|                     name='other' | ||||
|                     placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||
|                     placeholder={'请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||
|                     onChange={handleInputChange} | ||||
|                     value={inputs.other} | ||||
|                     autoComplete='new-password' | ||||
| @@ -181,25 +242,12 @@ const EditChannel = () => { | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           { | ||||
|             inputs.type !== 3 && inputs.type !== 8 && ( | ||||
|               <Form.Field> | ||||
|                 <Form.Input | ||||
|                   label='镜像' | ||||
|                   name='base_url' | ||||
|                   placeholder={'请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值'} | ||||
|                   onChange={handleInputChange} | ||||
|                   value={inputs.base_url} | ||||
|                   autoComplete='new-password' | ||||
|                 /> | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='名称' | ||||
|               required | ||||
|               name='name' | ||||
|               placeholder={'请输入名称'} | ||||
|               placeholder={'请为渠道命名'} | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.name} | ||||
|               autoComplete='new-password' | ||||
| @@ -208,8 +256,9 @@ const EditChannel = () => { | ||||
|           <Form.Field> | ||||
|             <Form.Dropdown | ||||
|               label='分组' | ||||
|               placeholder={'请选择分组'} | ||||
|               placeholder={'请选择可以使用该渠道的分组'} | ||||
|               name='groups' | ||||
|               required | ||||
|               fluid | ||||
|               multiple | ||||
|               selection | ||||
| @@ -224,8 +273,9 @@ const EditChannel = () => { | ||||
|           <Form.Field> | ||||
|             <Form.Dropdown | ||||
|               label='模型' | ||||
|               placeholder={'请选择该通道所支持的模型'} | ||||
|               placeholder={'请选择该渠道所支持的模型'} | ||||
|               name='models' | ||||
|               required | ||||
|               fluid | ||||
|               multiple | ||||
|               selection | ||||
| @@ -245,12 +295,50 @@ const EditChannel = () => { | ||||
|             <Button type={'button'} onClick={() => { | ||||
|               handleInputChange(null, { name: 'models', value: [] }); | ||||
|             }}>清除所有模型</Button> | ||||
|             <Input | ||||
|               action={ | ||||
|                 <Button type={'button'} onClick={() => { | ||||
|                   if (customModel.trim() === '') return; | ||||
|                   if (inputs.models.includes(customModel)) return; | ||||
|                   let localModels = [...inputs.models]; | ||||
|                   localModels.push(customModel); | ||||
|                   let localModelOptions = []; | ||||
|                   localModelOptions.push({ | ||||
|                     key: customModel, | ||||
|                     text: customModel, | ||||
|                     value: customModel | ||||
|                   }); | ||||
|                   setModelOptions(modelOptions => { | ||||
|                     return [...modelOptions, ...localModelOptions]; | ||||
|                   }); | ||||
|                   setCustomModel(''); | ||||
|                   handleInputChange(null, { name: 'models', value: localModels }); | ||||
|                 }}>填入</Button> | ||||
|               } | ||||
|               placeholder='输入自定义模型名称' | ||||
|               value={customModel} | ||||
|               onChange={(e, { value }) => { | ||||
|                 setCustomModel(value); | ||||
|               }} | ||||
|             /> | ||||
|           </div> | ||||
|           <Form.Field> | ||||
|             <Form.TextArea | ||||
|               label='模型重定向' | ||||
|               placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} | ||||
|               name='model_mapping' | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.model_mapping} | ||||
|               style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }} | ||||
|               autoComplete='new-password' | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           { | ||||
|             batch ? <Form.Field> | ||||
|               <Form.TextArea | ||||
|                 label='密钥' | ||||
|                 name='key' | ||||
|                 required | ||||
|                 placeholder={'请输入密钥,一行一个'} | ||||
|                 onChange={handleInputChange} | ||||
|                 value={inputs.key} | ||||
| @@ -261,7 +349,8 @@ const EditChannel = () => { | ||||
|               <Form.Input | ||||
|                 label='密钥' | ||||
|                 name='key' | ||||
|                 placeholder={'请输入密钥'} | ||||
|                 required | ||||
|                 placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} | ||||
|                 onChange={handleInputChange} | ||||
|                 value={inputs.key} | ||||
|                 autoComplete='new-password' | ||||
| @@ -278,7 +367,21 @@ const EditChannel = () => { | ||||
|               /> | ||||
|             ) | ||||
|           } | ||||
|           <Button positive onClick={submit}>提交</Button> | ||||
|           { | ||||
|             inputs.type !== 3 && inputs.type !== 8 && ( | ||||
|               <Form.Field> | ||||
|                 <Form.Input | ||||
|                   label='镜像' | ||||
|                   name='base_url' | ||||
|                   placeholder={'此项可选,用于通过镜像站来进行 API 调用,请输入镜像站地址,格式为:https://domain.com'} | ||||
|                   onChange={handleInputChange} | ||||
|                   value={inputs.base_url} | ||||
|                   autoComplete='new-password' | ||||
|                 /> | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button> | ||||
|         </Form> | ||||
|       </Segment> | ||||
|     </> | ||||
|   | ||||
| @@ -4,10 +4,7 @@ import LogsTable from '../../components/LogsTable'; | ||||
|  | ||||
| const Token = () => ( | ||||
|   <> | ||||
|     <Segment> | ||||
|       <Header as='h3'>额度明细</Header> | ||||
|       <LogsTable /> | ||||
|     </Segment> | ||||
|     <LogsTable /> | ||||
|   </> | ||||
| ); | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | ||||
| import { useParams } from 'react-router-dom'; | ||||
| import { useParams, useNavigate } from 'react-router-dom'; | ||||
| import { API, showError, showSuccess, timestamp2string } from '../../helpers'; | ||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | ||||
|  | ||||
| @@ -11,17 +11,19 @@ const EditToken = () => { | ||||
|   const [loading, setLoading] = useState(isEdit); | ||||
|   const originInputs = { | ||||
|     name: '', | ||||
|     remain_quota: 0, | ||||
|     remain_quota: isEdit ? 0 : 500000, | ||||
|     expired_time: -1, | ||||
|     unlimited_quota: false | ||||
|   }; | ||||
|   const [inputs, setInputs] = useState(originInputs); | ||||
|   const { name, remain_quota, expired_time, unlimited_quota } = inputs; | ||||
|  | ||||
|   const navigate = useNavigate(); | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|   }; | ||||
|  | ||||
|   const handleCancel = () => { | ||||
|     navigate("/token"); | ||||
|   } | ||||
|   const setExpiredTime = (month, day, hour, minute) => { | ||||
|     let now = new Date(); | ||||
|     let timestamp = now.getTime() / 1000; | ||||
| @@ -83,7 +85,7 @@ const EditToken = () => { | ||||
|       if (isEdit) { | ||||
|         showSuccess('令牌更新成功!'); | ||||
|       } else { | ||||
|         showSuccess('令牌创建成功!'); | ||||
|         showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!'); | ||||
|         setInputs(originInputs); | ||||
|       } | ||||
|     } else { | ||||
| @@ -150,8 +152,9 @@ const EditToken = () => { | ||||
|           </Form.Field> | ||||
|           <Button type={'button'} onClick={() => { | ||||
|             setUnlimitedQuota(); | ||||
|           }}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button> | ||||
|           <Button positive onClick={submit}>提交</Button> | ||||
|           }}>{unlimited_quota ? '取消无限额度' : '设为无限额度'}</Button> | ||||
|           <Button floated='right' positive onClick={submit}>提交</Button> | ||||
|           <Button floated='right' onClick={handleCancel}>取消</Button> | ||||
|         </Form> | ||||
|       </Segment> | ||||
|     </> | ||||
|   | ||||
| @@ -7,24 +7,32 @@ const TopUp = () => { | ||||
|   const [redemptionCode, setRedemptionCode] = useState(''); | ||||
|   const [topUpLink, setTopUpLink] = useState(''); | ||||
|   const [userQuota, setUserQuota] = useState(0); | ||||
|   const [isSubmitting, setIsSubmitting] = useState(false); | ||||
|  | ||||
|   const topUp = async () => { | ||||
|     if (redemptionCode === '') { | ||||
|       showInfo('请输入充值码!') | ||||
|       return; | ||||
|     } | ||||
|     const res = await API.post('/api/user/topup', { | ||||
|       key: redemptionCode | ||||
|     }); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       showSuccess('充值成功!'); | ||||
|       setUserQuota((quota) => { | ||||
|         return quota + data; | ||||
|     setIsSubmitting(true); | ||||
|     try { | ||||
|       const res = await API.post('/api/user/topup', { | ||||
|         key: redemptionCode | ||||
|       }); | ||||
|       setRedemptionCode(''); | ||||
|     } else { | ||||
|       showError(message); | ||||
|       const { success, message, data } = res.data; | ||||
|       if (success) { | ||||
|         showSuccess('充值成功!'); | ||||
|         setUserQuota((quota) => { | ||||
|           return quota + data; | ||||
|         }); | ||||
|         setRedemptionCode(''); | ||||
|       } else { | ||||
|         showError(message); | ||||
|       } | ||||
|     } catch (err) { | ||||
|       showError('请求失败'); | ||||
|     } finally { | ||||
|       setIsSubmitting(false);  | ||||
|     } | ||||
|   }; | ||||
|  | ||||
| @@ -74,8 +82,8 @@ const TopUp = () => { | ||||
|             <Button color='green' onClick={openTopUpLink}> | ||||
|               获取兑换码 | ||||
|             </Button> | ||||
|             <Button color='yellow' onClick={topUp}> | ||||
|               充值 | ||||
|             <Button color='yellow' onClick={topUp} disabled={isSubmitting}> | ||||
|                 {isSubmitting ? '兑换中...' : '兑换'} | ||||
|             </Button> | ||||
|           </Form> | ||||
|         </Grid.Column> | ||||
| @@ -92,5 +100,4 @@ const TopUp = () => { | ||||
|   ); | ||||
| }; | ||||
|  | ||||
|  | ||||
| export default TopUp; | ||||
| export default TopUp; | ||||
| @@ -1,7 +1,8 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Header, Segment } from 'semantic-ui-react'; | ||||
| import { useParams } from 'react-router-dom'; | ||||
| import { useParams, useNavigate } from 'react-router-dom'; | ||||
| import { API, showError, showSuccess } from '../../helpers'; | ||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | ||||
|  | ||||
| const EditUser = () => { | ||||
|   const params = useParams(); | ||||
| @@ -35,7 +36,10 @@ const EditUser = () => { | ||||
|       showError(error.message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const navigate = useNavigate(); | ||||
|   const handleCancel = () => { | ||||
|     navigate("/setting"); | ||||
|   } | ||||
|   const loadUser = async () => { | ||||
|     let res = undefined; | ||||
|     if (userId) { | ||||
| @@ -134,7 +138,7 @@ const EditUser = () => { | ||||
|               </Form.Field> | ||||
|               <Form.Field> | ||||
|                 <Form.Input | ||||
|                   label='剩余额度' | ||||
|                   label={`剩余额度${renderQuotaWithPrompt(quota)}`} | ||||
|                   name='quota' | ||||
|                   placeholder={'请输入新的剩余额度'} | ||||
|                   onChange={handleInputChange} | ||||
| @@ -175,6 +179,7 @@ const EditUser = () => { | ||||
|               readOnly | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Button onClick={handleCancel}>取消</Button> | ||||
|           <Button positive onClick={submit}>提交</Button> | ||||
|         </Form> | ||||
|       </Segment> | ||||
|   | ||||
		Reference in New Issue
	
	Block a user