mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-26 03:13:41 +08:00 
			
		
		
		
	Compare commits
	
		
			117 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | f823581235 | ||
|  | 89e6b9fe33 | ||
|  | 5a8fef00e5 | ||
|  | fe72f85554 | ||
|  | 3ac0b256e3 | ||
|  | b0fefd6dc5 | ||
|  | 43d8bedbb4 | ||
|  | c60f755715 | ||
|  | b7fcb319da | ||
|  | a4138aec1a | ||
|  | 67c64e71c8 | ||
|  | ffa4e491ea | ||
|  | 97030e27f8 | ||
|  | 461f5dab56 | ||
|  | af378c59af | ||
|  | bc6769826b | ||
|  | 0fe26cc4bd | ||
|  | 7d6a169669 | ||
|  | 66f06e5d6f | ||
|  | 6acb9537a9 | ||
|  | 7069c49bdf | ||
|  | 58dee76bf7 | ||
|  | 5cf23d8698 | ||
|  | 366b82128f | ||
|  | 2a70744dbf | ||
|  | 4c5feee0b6 | ||
|  | 9ba5388367 | ||
|  | 379074f7d0 | ||
|  | 365744a040 | ||
|  | 01f7b0186f | ||
|  | a3f80a3392 | ||
|  | 8f5b83562b | ||
|  | b7570d5c77 | ||
|  | 8bcaf182bc | ||
|  | 045e2fa139 | ||
|  | a884c4b0bf | ||
|  | c97c8a0f65 | ||
|  | 58fc40a744 | ||
|  | da87fca2a2 | ||
|  | 5e08cc8719 | ||
|  | d8b13b2c07 | ||
|  | be364ae09b | ||
|  | 2114bc1982 | ||
|  | 0f038d715d | ||
|  | 9dd92bbddd | ||
|  | 5b70ee3407 | ||
|  | 17027fb61e | ||
|  | a013b1a166 | ||
|  | 7c6dee7390 | ||
|  | 96dc7614e6 | ||
|  | 1c7c2d40bb | ||
|  | 455269c145 | ||
|  | 544f20cc73 | ||
|  | 902c2faa2c | ||
|  | 0e73418cdf | ||
|  | 9889377f0e | ||
|  | b273464e77 | ||
|  | b4e43d97fd | ||
|  | 3347a44023 | ||
|  | 923e24534b | ||
|  | b4d67ca614 | ||
|  | d85e356b6e | ||
|  | 53da7134b2 | ||
|  | 1fa1c66f13 | ||
|  | 495fc628e4 | ||
|  | 76f9288c34 | ||
|  | 915d13fdd4 | ||
|  | 969f539777 | ||
|  | 54e5f8ecd2 | ||
|  | 341c21e4cb | ||
|  | fe56aa1a46 | ||
|  | f0e2ba0318 | ||
|  | 43e7b465cb | ||
|  | 34d517cfa2 | ||
|  | ddcaf95f5f | ||
|  | 1d15157f7d | ||
|  | de7b9710a5 | ||
|  | 4f245bf738 | ||
|  | 58bb3ab6f6 | ||
|  | d306cb5229 | ||
|  | 6c5307d0c4 | ||
|  | 7c4505bdfc | ||
|  | 9d43ec57d8 | ||
|  | e5311892d1 | ||
|  | 56b3c939bf | ||
|  | 257135f676 | ||
|  | 84784ffccc | ||
|  | ef18eb9f93 | ||
|  | bc7c9105f4 | ||
|  | 3fe76c8af7 | ||
|  | c70c614018 | ||
|  | 0d87de697c | ||
|  | aec343dc38 | ||
|  | 12499aaf69 | ||
|  | 1e17944e4a | ||
|  | 28c29283c5 | ||
|  | cb3e9b8277 | ||
|  | 89d458b9cf | ||
|  | 63fafba112 | ||
|  | a398f35968 | ||
|  | 57aa637c77 | ||
|  | 3b483639a4 | ||
|  | 22980b4c44 | ||
|  | 64cdb7eafb | ||
|  | 824444244b | ||
|  | fbe9985f57 | ||
|  | a27a5bcc06 | ||
|  | e28d4b1741 | ||
|  | f073592d39 | ||
|  | fa41ca9805 | ||
|  | e338de45b6 | ||
|  | 114587b46f | ||
|  | b4b4acc288 | ||
|  | d663de3e3a | ||
|  | a85ecace2e | ||
|  | fbdea91ea1 | ||
|  | 8d34b7a77e | 
							
								
								
									
										46
									
								
								.air.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								.air.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | ||||
| root = "." | ||||
| testdata_dir = "testdata" | ||||
| tmp_dir = "tmp" | ||||
|  | ||||
| [build] | ||||
| args_bin = [] | ||||
| bin = "./tmp/main" | ||||
| cmd = "go build -o ./tmp/main ." | ||||
| delay = 1000 | ||||
| exclude_dir = ["assets", "tmp", "vendor", "testdata", "web"] | ||||
| exclude_file = [] | ||||
| exclude_regex = ["_test.go"] | ||||
| exclude_unchanged = false | ||||
| follow_symlink = false | ||||
| full_bin = "" | ||||
| include_dir = [] | ||||
| include_ext = ["go", "tpl", "tmpl", "html"] | ||||
| include_file = [] | ||||
| kill_delay = "0s" | ||||
| log = "build-errors.log" | ||||
| poll = false | ||||
| poll_interval = 0 | ||||
| post_cmd = [] | ||||
| pre_cmd = [] | ||||
| rerun = false | ||||
| rerun_delay = 500 | ||||
| send_interrupt = false | ||||
| stop_on_error = false | ||||
|  | ||||
| [color] | ||||
| app = "" | ||||
| build = "yellow" | ||||
| main = "magenta" | ||||
| runner = "green" | ||||
| watcher = "cyan" | ||||
|  | ||||
| [log] | ||||
| main_only = false | ||||
| time = false | ||||
|  | ||||
| [misc] | ||||
| clean_on_exit = false | ||||
|  | ||||
| [screen] | ||||
| clear_on_rebuild = false | ||||
| keep_scroll = true | ||||
							
								
								
									
										49
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										49
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,49 +0,0 @@ | ||||
| name: Publish Docker image (amd64, English) | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - '*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   push_to_registries: | ||||
|     name: Push Docker image to multiple registries | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       packages: write | ||||
|       contents: read | ||||
|     steps: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|  | ||||
|       - name: Translate | ||||
|         run: | | ||||
|           python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json | ||||
|       - name: Log in to Docker Hub | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} | ||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} | ||||
|  | ||||
|       - name: Extract metadata (tags, labels) for Docker | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api-en | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|         with: | ||||
|           context: . | ||||
|           push: true | ||||
|           tags: ${{ steps.meta.outputs.tags }} | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
							
								
								
									
										54
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										54
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,54 +0,0 @@ | ||||
| name: Publish Docker image (amd64) | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - '*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   push_to_registries: | ||||
|     name: Push Docker image to multiple registries | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       packages: write | ||||
|       contents: read | ||||
|     steps: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|  | ||||
|       - name: Log in to Docker Hub | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} | ||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} | ||||
|  | ||||
|       - name: Log in to the Container registry | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.actor }} | ||||
|           password: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
|       - name: Extract metadata (tags, labels) for Docker | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api | ||||
|             ghcr.io/${{ github.repository }} | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|         with: | ||||
|           context: . | ||||
|           push: true | ||||
|           tags: ${{ steps.meta.outputs.tags }} | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
							
								
								
									
										62
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,62 +0,0 @@ | ||||
| name: Publish Docker image (arm64) | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - '*' | ||||
|       - '!*-alpha*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   push_to_registries: | ||||
|     name: Push Docker image to multiple registries | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       packages: write | ||||
|       contents: read | ||||
|     steps: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|  | ||||
|       - name: Set up QEMU | ||||
|         uses: docker/setup-qemu-action@v2 | ||||
|  | ||||
|       - name: Set up Docker Buildx | ||||
|         uses: docker/setup-buildx-action@v2 | ||||
|  | ||||
|       - name: Log in to Docker Hub | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} | ||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} | ||||
|  | ||||
|       - name: Log in to the Container registry | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.actor }} | ||||
|           password: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
|       - name: Extract metadata (tags, labels) for Docker | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api | ||||
|             ghcr.io/${{ github.repository }} | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|         with: | ||||
|           context: . | ||||
|           platforms: linux/amd64,linux/arm64 | ||||
|           push: true | ||||
|           tags: ${{ steps.meta.outputs.tags }} | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
							
								
								
									
										62
									
								
								.github/workflows/docker-image.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/docker-image.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | ||||
| name: one-api docker image | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     branches: | ||||
|       - main | ||||
|     tags: | ||||
|       - "v*" | ||||
|  | ||||
| env: | ||||
|   # github.repository as <account>/<repo> | ||||
|   IMAGE_NAME: martialbe/one-api | ||||
|  | ||||
| jobs: | ||||
|   build-and-push: | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       packages: write | ||||
|       contents: read | ||||
|     steps: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|  | ||||
|       - name: Set up QEMU | ||||
|         uses: docker/setup-qemu-action@v2 | ||||
|  | ||||
|       - name: Set up Docker Buildx | ||||
|         uses: docker/setup-buildx-action@v2 | ||||
|  | ||||
|       - name: Login to GHCR | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.repository_owner }} | ||||
|           password: ${{ secrets.GT_Token }} | ||||
|  | ||||
|       - name: Docker meta | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           # list of Docker images to use as base name for tags | ||||
|           images: ghcr.io/${{ env.IMAGE_NAME }} | ||||
|           # generate Docker tags based on the following events/attributes | ||||
|           tags: | | ||||
|             type=raw,value=dev,enable=${{ github.ref == 'refs/heads/main' }} | ||||
|             type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }} | ||||
|             type=pep440,pattern={{raw}},enable=${{ startsWith(github.ref, 'refs/tags/') }} | ||||
|  | ||||
|       - name: Build and push | ||||
|         uses: docker/build-push-action@v4 | ||||
|         with: | ||||
|           context: . | ||||
|           platforms: linux/amd64 | ||||
|           build-args: | | ||||
|             COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }} | ||||
|           push: true | ||||
|           tags: ${{ steps.meta.outputs.tags }} | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
|           cache-from: type=gha | ||||
|           cache-to: type=gha,mode=max | ||||
							
								
								
									
										8
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,8 +5,8 @@ permissions: | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - '*' | ||||
|       - '!*-alpha*' | ||||
|       - "*" | ||||
|       - "!*-alpha*" | ||||
| jobs: | ||||
|   release: | ||||
|     runs-on: ubuntu-latest | ||||
| @@ -29,7 +29,7 @@ jobs: | ||||
|       - name: Set up Go | ||||
|         uses: actions/setup-go@v3 | ||||
|         with: | ||||
|           go-version: '>=1.18.0' | ||||
|           go-version: ">=1.18.0" | ||||
|       - name: Build Backend (amd64) | ||||
|         run: | | ||||
|           go mod download | ||||
| @@ -51,4 +51,4 @@ jobs: | ||||
|           draft: true | ||||
|           generate_release_notes: true | ||||
|         env: | ||||
|           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||||
|           GITHUB_TOKEN: ${{ secrets.GT_Token }} | ||||
|   | ||||
							
								
								
									
										8
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,8 +5,8 @@ permissions: | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - '*' | ||||
|       - '!*-alpha*' | ||||
|       - "*" | ||||
|       - "!*-alpha*" | ||||
| jobs: | ||||
|   release: | ||||
|     runs-on: macos-latest | ||||
| @@ -29,7 +29,7 @@ jobs: | ||||
|       - name: Set up Go | ||||
|         uses: actions/setup-go@v3 | ||||
|         with: | ||||
|           go-version: '>=1.18.0' | ||||
|           go-version: ">=1.18.0" | ||||
|       - name: Build Backend | ||||
|         run: | | ||||
|           go mod download | ||||
| @@ -42,4 +42,4 @@ jobs: | ||||
|           draft: true | ||||
|           generate_release_notes: true | ||||
|         env: | ||||
|           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||||
|           GITHUB_TOKEN: ${{ secrets.GT_Token }} | ||||
|   | ||||
							
								
								
									
										8
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,8 +5,8 @@ permissions: | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - '*' | ||||
|       - '!*-alpha*' | ||||
|       - "*" | ||||
|       - "!*-alpha*" | ||||
| jobs: | ||||
|   release: | ||||
|     runs-on: windows-latest | ||||
| @@ -32,7 +32,7 @@ jobs: | ||||
|       - name: Set up Go | ||||
|         uses: actions/setup-go@v3 | ||||
|         with: | ||||
|           go-version: '>=1.18.0' | ||||
|           go-version: ">=1.18.0" | ||||
|       - name: Build Backend | ||||
|         run: | | ||||
|           go mod download | ||||
| @@ -45,4 +45,4 @@ jobs: | ||||
|           draft: true | ||||
|           generate_release_notes: true | ||||
|         env: | ||||
|           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||||
|           GITHUB_TOKEN: ${{ secrets.GT_Token }} | ||||
|   | ||||
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -5,4 +5,7 @@ upload | ||||
| *.db | ||||
| build | ||||
| *.db-journal | ||||
| logs | ||||
| logs | ||||
| data | ||||
| tmp/ | ||||
| .env | ||||
							
								
								
									
										138
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										138
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -10,28 +10,38 @@ | ||||
|  | ||||
| # One API | ||||
|  | ||||
| _This project is a derivative of [one-api](https://github.com/songquanpeng/one-api), where the main focus has been on modularizing the module code from the original project and modifying the frontend interface. This project also adheres to the MIT License._ | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE"> | ||||
|     <img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license"> | ||||
|   </a> | ||||
|   <a href="https://github.com/MartialBE/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://github.com/users/MartialBE/packages/container/package/one-api"> | ||||
|     <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker"> | ||||
|   </a> | ||||
|   <a href="https://goreportcard.com/report/github.com/MartialBE/one-api"> | ||||
|     <img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard"> | ||||
|   </a> | ||||
| </p> | ||||
|  | ||||
| **Please do not mix with the original version, as the different channel ID may cause data disorder.** | ||||
|  | ||||
| ## Screenshots | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| _The following is the original project description:_ | ||||
|  | ||||
| --- | ||||
|  | ||||
| _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> | ||||
|     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> | ||||
|     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> | ||||
|     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> | ||||
|   </a> | ||||
| </p> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="#deployment">Deployment Tutorial</a> | ||||
|   · | ||||
| @@ -57,13 +67,14 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use | ||||
| > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | ||||
|  | ||||
| ## Features | ||||
|  | ||||
| 1. Support for multiple large models: | ||||
|    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) | ||||
|    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||
|    - [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    - [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||
|    - [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) | ||||
|    - [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    - [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    - [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||
| 2. Supports access to multiple channels through **load balancing**. | ||||
| 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | ||||
| 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | ||||
| @@ -82,13 +93,15 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use | ||||
| 15. Supports management API access through system access tokens. | ||||
| 16. Supports Cloudflare Turnstile user verification. | ||||
| 17. Supports user management and multiple user login/registration methods: | ||||
|     + Email login/registration and password reset via email. | ||||
|     + [GitHub OAuth](https://github.com/settings/applications/new). | ||||
|     + WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)). | ||||
|     - Email login/registration and password reset via email. | ||||
|     - [GitHub OAuth](https://github.com/settings/applications/new). | ||||
|     - WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)). | ||||
| 18. Immediate support and encapsulation of other major model APIs as they become available. | ||||
|  | ||||
| ## Deployment | ||||
|  | ||||
| ### Docker Deployment | ||||
|  | ||||
| Deployment command: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en` | ||||
|  | ||||
| Update command: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | ||||
| @@ -98,10 +111,11 @@ The first `3000` in `-p 3000:3000` is the port of the host, which can be modifie | ||||
| Data will be saved in the `/home/ubuntu/data/one-api` directory on the host. Ensure that the directory exists and has write permissions, or change it to a suitable directory. | ||||
|  | ||||
| Nginx reference configuration: | ||||
|  | ||||
| ``` | ||||
| server{ | ||||
|    server_name openai.justsong.cn;  # Modify your domain name accordingly | ||||
|     | ||||
|  | ||||
|    location / { | ||||
|           client_max_body_size  64m; | ||||
|           proxy_http_version 1.1; | ||||
| @@ -115,6 +129,7 @@ server{ | ||||
| ``` | ||||
|  | ||||
| Next, configure HTTPS with Let's Encrypt certbot: | ||||
|  | ||||
| ```bash | ||||
| # Install certbot on Ubuntu: | ||||
| sudo snap install --classic certbot | ||||
| @@ -129,20 +144,23 @@ sudo service nginx restart | ||||
| The initial account username is `root` and password is `123456`. | ||||
|  | ||||
| ### Manual Deployment | ||||
|  | ||||
| 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: | ||||
|  | ||||
|    ```shell | ||||
|    git clone https://github.com/songquanpeng/one-api.git | ||||
|     | ||||
|  | ||||
|    # Build the frontend | ||||
|    cd one-api/web | ||||
|    npm install | ||||
|    npm run build | ||||
|     | ||||
|  | ||||
|    # Build the backend | ||||
|    cd .. | ||||
|    go mod download | ||||
|    go build -ldflags "-s -w" -o one-api | ||||
|    ``` | ||||
|  | ||||
| 2. Run: | ||||
|    ```shell | ||||
|    chmod u+x one-api | ||||
| @@ -153,6 +171,7 @@ The initial account username is `root` and password is `123456`. | ||||
| For more detailed deployment tutorials, please refer to [this page](https://iamazing.cn/page/how-to-deploy-a-website). | ||||
|  | ||||
| ### Multi-machine Deployment | ||||
|  | ||||
| 1. Set the same `SESSION_SECRET` for all servers. | ||||
| 2. Set `SQL_DSN` and use MySQL instead of SQLite. All servers should connect to the same database. | ||||
| 3. Set the `NODE_TYPE` for all non-master nodes to `slave`. | ||||
| @@ -164,11 +183,13 @@ For more detailed deployment tutorials, please refer to [this page](https://iama | ||||
| Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | ||||
|  | ||||
| ### Deployment on Control Panels (e.g., Baota) | ||||
|  | ||||
| Refer to [#175](https://github.com/songquanpeng/one-api/issues/175) for detailed instructions. | ||||
|  | ||||
| If you encounter a blank page after deployment, refer to [#97](https://github.com/songquanpeng/one-api/issues/97) for possible solutions. | ||||
|  | ||||
| ### Deployment on Third-Party Platforms | ||||
|  | ||||
| <details> | ||||
| <summary><strong>Deploy on Sealos</strong></summary> | ||||
| <div> | ||||
| @@ -179,7 +200,6 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | ||||
|  | ||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||
|  | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
|  | ||||
| @@ -189,10 +209,12 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | ||||
|  | ||||
| > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. First, fork the code. | ||||
| 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. | ||||
| 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | ||||
| 4. Copy the connection parameters and run ```create database `one-api` ``` to create the database. | ||||
| 4. Copy the connection parameters and run `` create database `one-api`  `` to create the database. | ||||
| 5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. | ||||
| 6. Automatic deployment will start, but please cancel it for now. Go to the Variable tab, add a `PORT` with a value of `3000`, and then add a `SQL_DSN` with a value of `<username>:<password>@tcp(<addr>:<port>)/one-api`. Save the changes. Please note that if `SQL_DSN` is not set, data will not be persisted, and the data will be lost after redeployment. | ||||
| 7. Select Redeploy. | ||||
| @@ -203,6 +225,7 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | ||||
| </details> | ||||
|  | ||||
| ## Configuration | ||||
|  | ||||
| The system is ready to use out of the box. | ||||
|  | ||||
| You can configure it by setting environment variables or command line parameters. | ||||
| @@ -210,6 +233,7 @@ You can configure it by setting environment variables or command line parameters | ||||
| After the system starts, log in as the `root` user to further configure the system. | ||||
|  | ||||
| ## Usage | ||||
|  | ||||
| Add your API Key on the `Channels` page, and then add an access token on the `Tokens` page. | ||||
|  | ||||
| You can then use your access token to access One API. The usage is consistent with the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). | ||||
| @@ -233,59 +257,65 @@ Note that the token needs to be created by an administrator to specify the chann | ||||
| If the channel ID is not provided, load balancing will be used to distribute the requests to multiple channels. | ||||
|  | ||||
| ### Environment Variables | ||||
|  | ||||
| 1. `REDIS_CONN_STRING`: When set, Redis will be used as the storage for request rate limiting instead of memory. | ||||
|     + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    - Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
| 2. `SESSION_SECRET`: When set, a fixed session key will be used to ensure that cookies of logged-in users are still valid after the system restarts. | ||||
|     + Example: `SESSION_SECRET=random_string` | ||||
|    - Example: `SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | ||||
|     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
|    - Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
| 4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
|    - Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
|     + Example: `SYNC_FREQUENCY=60` | ||||
|    - Example: `SYNC_FREQUENCY=60` | ||||
| 6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
|     + Example: `NODE_TYPE=slave` | ||||
|    - Example: `NODE_TYPE=slave` | ||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
|    - Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||
|    - Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||
| 9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
|     + Example: `POLLING_INTERVAL=5` | ||||
|    - Example: `POLLING_INTERVAL=5` | ||||
|  | ||||
| ### Command Line Parameters | ||||
|  | ||||
| 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. | ||||
|     + Example: `--port 3000` | ||||
|    - Example: `--port 3000` | ||||
| 2. `--log-dir <log_dir>`: Specifies the log directory. If not set, the logs will not be saved. | ||||
|     + Example: `--log-dir ./logs` | ||||
|    - Example: `--log-dir ./logs` | ||||
| 3. `--version`: Prints the system version number and exits. | ||||
| 4. `--help`: Displays the command usage help and parameter descriptions. | ||||
|  | ||||
| ## Screenshots | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| ## FAQ | ||||
|  | ||||
| 1. What is quota? How is it calculated? Does One API have quota calculation issues? | ||||
|     + Quota = Group multiplier * Model multiplier * (number of prompt tokens + number of completion tokens * completion multiplier) | ||||
|     + The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition. | ||||
|     + If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different. | ||||
|    - Quota = Group multiplier _ Model multiplier _ (number of prompt tokens + number of completion tokens \* completion multiplier) | ||||
|    - The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition. | ||||
|    - If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different. | ||||
| 2. Why does it prompt "insufficient quota" even though my account balance is sufficient? | ||||
|     + Please check if your token quota is sufficient. It is separate from the account balance. | ||||
|     + The token quota is used to set the maximum usage and can be freely set by the user. | ||||
|    - Please check if your token quota is sufficient. It is separate from the account balance. | ||||
|    - The token quota is used to set the maximum usage and can be freely set by the user. | ||||
| 3. It says "No available channels" when trying to use a channel. What should I do? | ||||
|     + Please check the user and channel group settings. | ||||
|     + Also check the channel model settings. | ||||
|    - Please check the user and channel group settings. | ||||
|    - Also check the channel model settings. | ||||
| 4. Channel testing reports an error: "invalid character '<' looking for beginning of value" | ||||
|     + This error occurs when the returned value is not valid JSON but an HTML page. | ||||
|     + Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare. | ||||
|    - This error occurs when the returned value is not valid JSON but an HTML page. | ||||
|    - Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare. | ||||
| 5. ChatGPT Next Web reports an error: "Failed to fetch" | ||||
|     + Do not set `BASE_URL` during deployment. | ||||
|     + Double-check that your interface address and API Key are correct. | ||||
|    - Do not set `BASE_URL` during deployment. | ||||
|    - Double-check that your interface address and API Key are correct. | ||||
|  | ||||
| ## Related Projects | ||||
|  | ||||
| [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||
|  | ||||
| ## Note | ||||
|  | ||||
| This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | ||||
|  | ||||
| This project is released under the MIT license. Based on this, attribution and a link to this project must be included at the bottom of the page. | ||||
|   | ||||
							
								
								
									
										136
									
								
								README.ja.md
									
									
									
									
									
								
							
							
						
						
									
										136
									
								
								README.ja.md
									
									
									
									
									
								
							| @@ -10,28 +10,38 @@ | ||||
|  | ||||
| # One API | ||||
|  | ||||
| _このプロジェクトは、[one-api](https://github.com/songquanpeng/one-api)をベースにしており、元のプロジェクトのモジュールコードを分離し、モジュール化し、フロントエンドのインターフェースを変更しました。このプロジェクトも MIT ライセンスに従っています。_ | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE"> | ||||
|     <img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license"> | ||||
|   </a> | ||||
|   <a href="https://github.com/MartialBE/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://github.com/users/MartialBE/packages/container/package/one-api"> | ||||
|     <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker"> | ||||
|   </a> | ||||
|   <a href="https://goreportcard.com/report/github.com/MartialBE/one-api"> | ||||
|     <img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard"> | ||||
|   </a> | ||||
| </p> | ||||
|  | ||||
| **オリジナルバージョンと混合しないでください。チャンネル ID が異なるため、データの混乱を引き起こす可能性があります** | ||||
|  | ||||
| ## スクリーンショット | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| _以下は元の項目の説明です:_ | ||||
|  | ||||
| --- | ||||
|  | ||||
| _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_ | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> | ||||
|     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> | ||||
|     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> | ||||
|     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> | ||||
|   </a> | ||||
| </p> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="#deployment">デプロイチュートリアル</a> | ||||
|   · | ||||
| @@ -57,13 +67,14 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に | ||||
| > **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。 | ||||
|  | ||||
| ## 特徴 | ||||
|  | ||||
| 1. 複数の大型モデルをサポート: | ||||
|    + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) | ||||
|    + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) | ||||
|    + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) | ||||
|    - [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) | ||||
|    - [x] [Anthropic Claude シリーズモデル](https://anthropic.com) | ||||
|    - [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) | ||||
|    - [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    - [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    - [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) | ||||
| 2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。 | ||||
| 3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。 | ||||
| 4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。 | ||||
| @@ -82,13 +93,15 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に | ||||
| 15. システム・アクセストークンによる管理 API アクセスをサポートする。 | ||||
| 16. Cloudflare Turnstile によるユーザー認証に対応。 | ||||
| 17. ユーザー管理と複数のユーザーログイン/登録方法をサポート: | ||||
|     + 電子メールによるログイン/登録とパスワードリセット。 | ||||
|     + [GitHub OAuth](https://github.com/settings/applications/new)。 | ||||
|     + WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。 | ||||
|     - 電子メールによるログイン/登録とパスワードリセット。 | ||||
|     - [GitHub OAuth](https://github.com/settings/applications/new)。 | ||||
|     - WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。 | ||||
| 18. 他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。 | ||||
|  | ||||
| ## デプロイメント | ||||
|  | ||||
| ### Docker デプロイメント | ||||
|  | ||||
| デプロイコマンド: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`。 | ||||
|  | ||||
| コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`。 | ||||
| @@ -97,7 +110,8 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に | ||||
|  | ||||
| データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。 | ||||
|  | ||||
| Nginxリファレンス設定: | ||||
| Nginx リファレンス設定: | ||||
|  | ||||
| ``` | ||||
| server{ | ||||
|    server_name openai.justsong.cn;  # ドメイン名は適宜変更 | ||||
| @@ -116,6 +130,7 @@ server{ | ||||
| ``` | ||||
|  | ||||
| 次に、Let's Encrypt certbot を使って HTTPS を設定します: | ||||
|  | ||||
| ```bash | ||||
| # Ubuntu に certbot をインストール: | ||||
| sudo snap install --classic certbot | ||||
| @@ -130,7 +145,9 @@ sudo service nginx restart | ||||
| 初期アカウントのユーザー名は `root` で、パスワードは `123456` です。 | ||||
|  | ||||
| ### マニュアルデプロイ | ||||
|  | ||||
| 1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする: | ||||
|  | ||||
|    ```shell | ||||
|    git clone https://github.com/songquanpeng/one-api.git | ||||
|  | ||||
| @@ -144,6 +161,7 @@ sudo service nginx restart | ||||
|    go mod download | ||||
|    go build -ldflags "-s -w" -o one-api | ||||
|    ``` | ||||
|  | ||||
| 2. 実行: | ||||
|    ```shell | ||||
|    chmod u+x one-api | ||||
| @@ -154,6 +172,7 @@ sudo service nginx restart | ||||
| より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。 | ||||
|  | ||||
| ### マルチマシンデプロイ | ||||
|  | ||||
| 1. すべてのサーバに同じ `SESSION_SECRET` を設定する。 | ||||
| 2. `SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。 | ||||
| 3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。 | ||||
| @@ -165,11 +184,13 @@ sudo service nginx restart | ||||
| Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | ||||
|  | ||||
| ### コントロールパネル(例: Baota)への展開 | ||||
|  | ||||
| 詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。 | ||||
|  | ||||
| 配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。 | ||||
|  | ||||
| ### サードパーティプラットフォームへのデプロイ | ||||
|  | ||||
| <details> | ||||
| <summary><strong>Sealos へのデプロイ</strong></summary> | ||||
| <div> | ||||
| @@ -180,7 +201,6 @@ Please refer to the [environment variables](#environment-variables) section for | ||||
|  | ||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||
|  | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
|  | ||||
| @@ -190,10 +210,12 @@ Please refer to the [environment variables](#environment-variables) section for | ||||
|  | ||||
| > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. まず、コードをフォークする。 | ||||
| 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 | ||||
| 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 | ||||
| 4. 接続パラメータをコピーし、```create database `one-api` ``` を実行してデータベースを作成する。 | ||||
| 3. 新しいプロジェクトを作成します。Service -> Add Service で Marketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 | ||||
| 4. 接続パラメータをコピーし、`` create database `one-api`  `` を実行してデータベースを作成する。 | ||||
| 5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。 | ||||
| 6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `<username>:<password>@tcp(<addr>:<port>)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。 | ||||
| 7. 再デプロイを選択します。 | ||||
| @@ -204,6 +226,7 @@ Please refer to the [environment variables](#environment-variables) section for | ||||
| </details> | ||||
|  | ||||
| ## コンフィグ | ||||
|  | ||||
| システムは箱から出してすぐに使えます。 | ||||
|  | ||||
| 環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。 | ||||
| @@ -211,6 +234,7 @@ Please refer to the [environment variables](#environment-variables) section for | ||||
| システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。 | ||||
|  | ||||
| ## 使用方法 | ||||
|  | ||||
| `Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。 | ||||
|  | ||||
| アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。 | ||||
| @@ -234,59 +258,65 @@ graph LR | ||||
| もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。 | ||||
|  | ||||
| ### 環境変数 | ||||
|  | ||||
| 1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。 | ||||
|     + 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    - 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
| 2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。 | ||||
|     + 例: `SESSION_SECRET=random_string` | ||||
|    - 例: `SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | ||||
|     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
|    - 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
| 4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | ||||
|     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
|    - 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | ||||
|     + 例: `SYNC_FREQUENCY=60` | ||||
|    - 例: `SYNC_FREQUENCY=60` | ||||
| 6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | ||||
|     + 例: `NODE_TYPE=slave` | ||||
|    - 例: `NODE_TYPE=slave` | ||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | ||||
|     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
|    - 例: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | ||||
|     + 例: `CHANNEL_TEST_FREQUENCY=1440` | ||||
|    - 例: `CHANNEL_TEST_FREQUENCY=1440` | ||||
| 9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | ||||
|     + 例: `POLLING_INTERVAL=5` | ||||
|    - 例: `POLLING_INTERVAL=5` | ||||
|  | ||||
| ### コマンドラインパラメータ | ||||
|  | ||||
| 1. `--port <port_number>`: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。 | ||||
|     + 例: `--port 3000` | ||||
|    - 例: `--port 3000` | ||||
| 2. `--log-dir <log_dir>`: ログディレクトリを指定。設定しない場合、ログは保存されません。 | ||||
|     + 例: `--log-dir ./logs` | ||||
|    - 例: `--log-dir ./logs` | ||||
| 3. `--version`: システムのバージョン番号を表示して終了する。 | ||||
| 4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。 | ||||
|  | ||||
| ## スクリーンショット | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| ## FAQ | ||||
|  | ||||
| 1. ノルマとは何か?どのように計算されますか?One API にはノルマ計算の問題はありますか? | ||||
|     + ノルマ = グループ倍率 * モデル倍率 * (プロンプトトークンの数 + 完了トークンの数 * 完了倍率) | ||||
|     + 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。 | ||||
|     + ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。 | ||||
|    - ノルマ = グループ倍率 _ モデル倍率 _ (プロンプトトークンの数 + 完了トークンの数 \* 完了倍率) | ||||
|    - 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。 | ||||
|    - ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。 | ||||
| 2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか? | ||||
|     + トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。 | ||||
|     + トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。 | ||||
|    - トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。 | ||||
|    - トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。 | ||||
| 3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか? | ||||
|     + ユーザーとチャンネルグループの設定を確認してください。 | ||||
|     + チャンネルモデルの設定も確認してください。 | ||||
|    - ユーザーとチャンネルグループの設定を確認してください。 | ||||
|    - チャンネルモデルの設定も確認してください。 | ||||
| 4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value" | ||||
|     + このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。 | ||||
|     + ほとんどの場合、デプロイサイトのIPかプロキシのノードが CloudFlare によってブロックされています。 | ||||
|    - このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。 | ||||
|    - ほとんどの場合、デプロイサイトの IP かプロキシのノードが CloudFlare によってブロックされています。 | ||||
| 5. ChatGPT Next Web でエラーが発生しました: "Failed to fetch" | ||||
|     + デプロイ時に `BASE_URL` を設定しないでください。 | ||||
|     + インターフェイスアドレスと API Key が正しいか再確認してください。 | ||||
|    - デプロイ時に `BASE_URL` を設定しないでください。 | ||||
|    - インターフェイスアドレスと API Key が正しいか再確認してください。 | ||||
|  | ||||
| ## 関連プロジェクト | ||||
|  | ||||
| [FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム | ||||
|  | ||||
| ## 注 | ||||
|  | ||||
| 本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。 | ||||
|  | ||||
| このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。 | ||||
|   | ||||
							
								
								
									
										252
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										252
									
								
								README.md
									
									
									
									
									
								
							| @@ -2,37 +2,46 @@ | ||||
|    <strong>中文</strong> | <a href="./README.en.md">English</a> | <a href="./README.ja.md">日本語</a> | ||||
| </p> | ||||
|  | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||
|   <a href="https://github.com/MartialBE/one-api"><img src="https://raw.githubusercontent.com/MartialBE/one-api/main/web/src/assets/images/logo.svg" width="150" height="150" alt="one-api logo"></a> | ||||
| </p> | ||||
|  | ||||
| <div align="center"> | ||||
|  | ||||
| # One API | ||||
|  | ||||
| _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开发而来的,主要将原项目中的模块代码分离,模块化,并修改了前端界面。本项目同样遵循 MIT 协议。_ | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE"> | ||||
|     <img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license"> | ||||
|   </a> | ||||
|   <a href="https://github.com/MartialBE/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://github.com/users/MartialBE/packages/container/package/one-api"> | ||||
|     <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker"> | ||||
|   </a> | ||||
|   <a href="https://goreportcard.com/report/github.com/MartialBE/one-api"> | ||||
|     <img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard"> | ||||
|   </a> | ||||
| </p> | ||||
|  | ||||
| **请不要和原版混用,因为 channel id 不同的原因,会导致数据错乱** | ||||
|  | ||||
| # 截图展示 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| _以下为原项目说明:_ | ||||
|  | ||||
| --- | ||||
|  | ||||
| _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ | ||||
|  | ||||
| </div> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> | ||||
|     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> | ||||
|     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> | ||||
|   </a> | ||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> | ||||
|     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> | ||||
|   </a> | ||||
|   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> | ||||
|     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> | ||||
|   </a> | ||||
| </p> | ||||
|  | ||||
| <p align="center"> | ||||
|   <a href="https://github.com/songquanpeng/one-api#部署">部署教程</a> | ||||
|   · | ||||
| @@ -51,34 +60,30 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|   <a href="https://iamazing.cn/page/reward">赞赏支持</a> | ||||
| </p> | ||||
|  | ||||
| > **Note** | ||||
| > [!NOTE] | ||||
| > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||
| >  | ||||
| > | ||||
| > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | ||||
|  | ||||
| > **Warning** | ||||
| > [!WARNING] | ||||
| > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | ||||
|  | ||||
| > **Warning** | ||||
| > [!WARNING] | ||||
| > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | ||||
|  | ||||
| ## 功能 | ||||
|  | ||||
| 1. 支持多种大模型: | ||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) | ||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||
|    + [x] [360 智脑](https://ai.360.cn) | ||||
| 2. 支持配置镜像以及众多第三方代理服务: | ||||
|    + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||
|    + [x] [API2D](https://api2d.com/r/197971) | ||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||
|    + [x] 自定义渠道:例如各种未收录的第三方代理服务 | ||||
|    - [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    - [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||
|    - [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) | ||||
|    - [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    - [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    - [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
|    - [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||
|    - [x] [360 智脑](https://ai.360.cn) | ||||
|    - [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||
| @@ -91,21 +96,24 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 12. 支持**用户邀请奖励**。 | ||||
| 13. 支持以美元为单位显示额度。 | ||||
| 14. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||
| 15. 支持模型映射,重定向用户的请求模型。 | ||||
| 15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。 | ||||
| 16. 支持失败自动重试。 | ||||
| 17. 支持绘图接口。 | ||||
| 18. 支持丰富的**自定义**设置, | ||||
| 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 | ||||
| 19. 支持丰富的**自定义**设置, | ||||
|     1. 支持自定义系统名称,logo 以及页脚。 | ||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||
| 19. 支持通过系统访问令牌访问管理 API。 | ||||
| 20. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 21. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     - 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     - [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     - 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
|  | ||||
| ## 部署 | ||||
|  | ||||
| ### 基于 Docker 进行部署 | ||||
|  | ||||
| ```shell | ||||
| # 使用 SQLite 的部署命令: | ||||
| docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||
| @@ -127,10 +135,11 @@ docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234 | ||||
| 更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | ||||
|  | ||||
| Nginx 的参考配置: | ||||
|  | ||||
| ``` | ||||
| server{ | ||||
|    server_name openai.justsong.cn;  # 请根据实际情况修改你的域名 | ||||
|     | ||||
|  | ||||
|    location / { | ||||
|           client_max_body_size  64m; | ||||
|           proxy_http_version 1.1; | ||||
| @@ -145,6 +154,7 @@ server{ | ||||
| ``` | ||||
|  | ||||
| 之后使用 Let's Encrypt 的 certbot 配置 HTTPS: | ||||
|  | ||||
| ```bash | ||||
| # Ubuntu 安装 certbot: | ||||
| sudo snap install --classic certbot | ||||
| @@ -158,21 +168,36 @@ sudo service nginx restart | ||||
|  | ||||
| 初始账号用户名为 `root`,密码为 `123456`。 | ||||
|  | ||||
| ### 基于 Docker Compose 进行部署 | ||||
|  | ||||
| > 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 | ||||
|  | ||||
| ```shell | ||||
| # 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 | ||||
| docker-compose up -d | ||||
|  | ||||
| # 查看部署状态 | ||||
| docker-compose ps | ||||
| ``` | ||||
|  | ||||
| ### 手动部署 | ||||
|  | ||||
| 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: | ||||
|  | ||||
|    ```shell | ||||
|    git clone https://github.com/songquanpeng/one-api.git | ||||
|     | ||||
|  | ||||
|    # 构建前端 | ||||
|    cd one-api/web | ||||
|    npm install | ||||
|    npm run build | ||||
|     | ||||
|  | ||||
|    # 构建后端 | ||||
|    cd .. | ||||
|    go mod download | ||||
|    go build -ldflags "-s -w" -o one-api | ||||
|    ```` | ||||
|    ``` | ||||
|  | ||||
| 2. 运行: | ||||
|    ```shell | ||||
|    chmod u+x one-api | ||||
| @@ -183,6 +208,7 @@ sudo service nginx restart | ||||
| 更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。 | ||||
|  | ||||
| ### 多机部署 | ||||
|  | ||||
| 1. 所有服务器 `SESSION_SECRET` 设置一样的值。 | ||||
| 2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,所有服务器连接同一个数据库。 | ||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | ||||
| @@ -200,9 +226,11 @@ sudo service nginx restart | ||||
| 如果部署后访问出现空白页面,详见 [#97](https://github.com/songquanpeng/one-api/issues/97)。 | ||||
|  | ||||
| ### 部署第三方服务配合 One API 使用 | ||||
|  | ||||
| > 欢迎 PR 添加更多示例。 | ||||
|  | ||||
| #### ChatGPT Next Web | ||||
|  | ||||
| 项目主页:https://github.com/Yidadaa/ChatGPT-Next-Web | ||||
|  | ||||
| ```bash | ||||
| @@ -212,6 +240,7 @@ docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web | ||||
| 注意修改端口号,之后在页面上设置接口地址(例如:https://openai.justsong.cn/ )和 API Key 即可。 | ||||
|  | ||||
| #### ChatGPT Web | ||||
|  | ||||
| 项目主页:https://github.com/Chanzhaoyu/chatgpt-web | ||||
|  | ||||
| ```bash | ||||
| @@ -220,14 +249,16 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
|  | ||||
| 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | ||||
|  | ||||
| #### QChatGPT - QQ机器人 | ||||
| #### QChatGPT - QQ 机器人 | ||||
|  | ||||
| 项目主页:https://github.com/RockChinQ/QChatGPT | ||||
|  | ||||
| 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||
| 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的 key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||
|  | ||||
| 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | ||||
|  | ||||
| ### 部署到第三方平台 | ||||
|  | ||||
| <details> | ||||
| <summary><strong>部署到 Sealos </strong></summary> | ||||
| <div> | ||||
| @@ -247,10 +278,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
|  | ||||
| > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. 首先 fork 一份代码。 | ||||
| 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | ||||
| 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 | ||||
| 4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。 | ||||
| 4. 复制链接参数,运行 `` create database `one-api`  `` 创建数据库。 | ||||
| 5. 然后在 Service -> Add Service,选择 Git(第一次使用需要先授权),选择你 fork 的仓库。 | ||||
| 6. Deploy 会自动开始,先取消。进入下方 Variable,添加一个 `PORT`,值为 `3000`,再添加一个 `SQL_DSN`,值为 `<username>:<password>@tcp(<addr>:<port>)/one-api` ,然后保存。 注意如果不填写 `SQL_DSN`,数据将无法持久化,重新部署后数据会丢失。 | ||||
| 7. 选择 Redeploy。 | ||||
| @@ -272,6 +305,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo | ||||
| </details> | ||||
|  | ||||
| ## 配置 | ||||
|  | ||||
| 系统本身开箱即用。 | ||||
|  | ||||
| 你可以通过设置环境变量或者命令行参数进行配置。 | ||||
| @@ -281,6 +315,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo | ||||
| **Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。 | ||||
|  | ||||
| ## 使用方法 | ||||
|  | ||||
| 在`渠道`页面中添加你的 API Key,之后在`令牌`页面中新增访问令牌。 | ||||
|  | ||||
| 之后就可以使用你的令牌访问 One API 了,使用方式与 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 一致。 | ||||
| @@ -290,9 +325,10 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo | ||||
| 注意,具体的 API Base 的格式取决于你所使用的客户端。 | ||||
|  | ||||
| 例如对于 OpenAI 的官方库: | ||||
|  | ||||
| ```bash | ||||
| OPENAI_API_KEY="sk-xxxxxx" | ||||
| OPENAI_API_BASE="https://<HOST>:<PORT>/v1"  | ||||
| OPENAI_API_BASE="https://<HOST>:<PORT>/v1" | ||||
| ``` | ||||
|  | ||||
| ```mermaid | ||||
| @@ -311,88 +347,106 @@ graph LR | ||||
| 不加的话将会使用负载均衡的方式使用多个渠道。 | ||||
|  | ||||
| ### 环境变量 | ||||
|  | ||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||
|    - 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    - 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||
|    + 例子:`SESSION_SECRET=random_string` | ||||
|    - 例子:`SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||
|    + 例子: | ||||
|      + MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
|      + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) | ||||
|    + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 | ||||
|    + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 | ||||
|    + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 | ||||
|    + 请根据你的数据库配置修改下列参数(或者保持默认值): | ||||
|      + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 | ||||
|      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 | ||||
|        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 | ||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||
|    - 例子: | ||||
|      - MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||
|      - PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) | ||||
|    - 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 | ||||
|    - 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 | ||||
|    - 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 | ||||
|    - 请根据你的数据库配置修改下列参数(或者保持默认值): | ||||
|      - `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 | ||||
|      - `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 | ||||
|        - 如果报错 `Error 1040: Too many connections`,请适当减小该值。 | ||||
|      - `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
|    - 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||
|    - 例子:`MEMORY_CACHE_ENABLED=true` | ||||
| 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||
|    + 例子:`SYNC_FREQUENCY=60` | ||||
|    - 例子:`SYNC_FREQUENCY=60` | ||||
| 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||
|    + 例子:`NODE_TYPE=slave` | ||||
|    - 例子:`NODE_TYPE=slave` | ||||
| 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||
|    - 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
|    - 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|     + 例子:`POLLING_INTERVAL=5` | ||||
|     - 例子:`POLLING_INTERVAL=5` | ||||
| 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||
|     - 例子:`BATCH_UPDATE_ENABLED=true` | ||||
|     - 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||
| 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||
|     - 例子:`BATCH_UPDATE_INTERVAL=5` | ||||
| 13. 请求频率限制: | ||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||
|     - `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||
|     - `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||
| 14. 编码器缓存设置: | ||||
|     - `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||
|     - `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
|  | ||||
| ### 命令行参数 | ||||
|  | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
|    + 例子:`--port 3000` | ||||
|    - 例子:`--port 3000` | ||||
| 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 | ||||
|    + 例子:`--log-dir ./logs` | ||||
|    - 例子:`--log-dir ./logs` | ||||
| 3. `--version`: 打印系统版本号并退出。 | ||||
| 4. `--help`: 查看命令的使用帮助和参数说明。 | ||||
|  | ||||
| ## 演示 | ||||
|  | ||||
| ### 在线演示 | ||||
|  | ||||
| 注意,该演示站不提供对外服务: | ||||
| https://openai.justsong.cn | ||||
|  | ||||
| ### 截图展示 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| ## 常见问题 | ||||
|  | ||||
| 1. 额度是什么?怎么计算的?One API 的额度计算有问题? | ||||
|    + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) | ||||
|    + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 | ||||
|    + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 | ||||
|    + 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 | ||||
|    - 额度 = 分组倍率 _ 模型倍率 _ (提示 token 数 + 补全 token 数 \* 补全倍率) | ||||
|    - 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 | ||||
|    - 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 | ||||
|    - 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 | ||||
| 2. 账户额度足够为什么提示额度不足? | ||||
|    + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 | ||||
|    + 令牌额度仅供用户设置最大使用量,用户可自由设置。 | ||||
|    - 请检查你的令牌额度是否足够,这个和账户额度是分开的。 | ||||
|    - 令牌额度仅供用户设置最大使用量,用户可自由设置。 | ||||
| 3. 提示无可用渠道? | ||||
|    + 请检查的用户分组和渠道分组设置。 | ||||
|    + 以及渠道的模型设置。 | ||||
|    - 请检查的用户分组和渠道分组设置。 | ||||
|    - 以及渠道的模型设置。 | ||||
| 4. 渠道测试报错:`invalid character '<' looking for beginning of value` | ||||
|    + 这是因为返回值不是合法的 JSON,而是一个 HTML 页面。 | ||||
|    + 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。 | ||||
|    - 这是因为返回值不是合法的 JSON,而是一个 HTML 页面。 | ||||
|    - 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。 | ||||
| 5. ChatGPT Next Web 报错:`Failed to fetch` | ||||
|    + 部署的时候不要设置 `BASE_URL`。 | ||||
|    + 检查你的接口地址和 API Key 有没有填对。 | ||||
|    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||
|    - 部署的时候不要设置 `BASE_URL`。 | ||||
|    - 检查你的接口地址和 API Key 有没有填对。 | ||||
|    - 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||
|    + 上游通道 429 了。 | ||||
|    - 上游通道 429 了。 | ||||
| 7. 升级之后我的数据会丢失吗? | ||||
|    - 如果使用 MySQL,不会。 | ||||
|    - 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||
| 8. 升级之前数据库需要做变更吗? | ||||
|    - 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||
|    - 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||
|  | ||||
| ## 相关项目 | ||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||
| * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用 | ||||
|  | ||||
| - [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||
| - [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 | ||||
|  | ||||
| ## 注意 | ||||
|  | ||||
|   | ||||
							
								
								
									
										257
									
								
								common/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										257
									
								
								common/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,257 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/types" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| var HttpClient *http.Client | ||||
|  | ||||
| func init() { | ||||
| 	if RelayTimeout == 0 { | ||||
| 		HttpClient = &http.Client{} | ||||
| 	} else { | ||||
| 		HttpClient = &http.Client{ | ||||
| 			Timeout: time.Duration(RelayTimeout) * time.Second, | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type Client struct { | ||||
| 	requestBuilder    RequestBuilder | ||||
| 	CreateFormBuilder func(io.Writer) FormBuilder | ||||
| } | ||||
|  | ||||
| func NewClient() *Client { | ||||
| 	return &Client{ | ||||
| 		requestBuilder: NewRequestBuilder(), | ||||
| 		CreateFormBuilder: func(body io.Writer) FormBuilder { | ||||
| 			return NewFormBuilder(body) | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type requestOptions struct { | ||||
| 	body   any | ||||
| 	header http.Header | ||||
| } | ||||
|  | ||||
| type requestOption func(*requestOptions) | ||||
|  | ||||
| type Stringer interface { | ||||
| 	GetString() *string | ||||
| } | ||||
|  | ||||
| func WithBody(body any) requestOption { | ||||
| 	return func(args *requestOptions) { | ||||
| 		args.body = body | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func WithHeader(header map[string]string) requestOption { | ||||
| 	return func(args *requestOptions) { | ||||
| 		for k, v := range header { | ||||
| 			args.header.Set(k, v) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func WithContentType(contentType string) requestOption { | ||||
| 	return func(args *requestOptions) { | ||||
| 		args.header.Set("Content-Type", contentType) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type RequestError struct { | ||||
| 	HTTPStatusCode int | ||||
| 	Err            error | ||||
| } | ||||
|  | ||||
| func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) { | ||||
| 	// Default Options | ||||
| 	args := &requestOptions{ | ||||
| 		body:   nil, | ||||
| 		header: make(http.Header), | ||||
| 	} | ||||
| 	for _, setter := range setters { | ||||
| 		setter(args) | ||||
| 	} | ||||
| 	req, err := c.requestBuilder.Build(method, url, args.body, args.header) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return req, nil | ||||
| } | ||||
|  | ||||
| func SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) { | ||||
| 	// 发送请求 | ||||
| 	resp, err := HttpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if !outputResp { | ||||
| 		defer resp.Body.Close() | ||||
| 	} | ||||
|  | ||||
| 	// 处理响应 | ||||
| 	if IsFailureStatusCode(resp) { | ||||
| 		return nil, HandleErrorResp(resp) | ||||
| 	} | ||||
|  | ||||
| 	// 解析响应 | ||||
| 	if outputResp { | ||||
| 		var buf bytes.Buffer | ||||
| 		tee := io.TeeReader(resp.Body, &buf) | ||||
| 		err = DecodeResponse(tee, response) | ||||
|  | ||||
| 		// 将响应体重新写入 resp.Body | ||||
| 		resp.Body = io.NopCloser(&buf) | ||||
| 	} else { | ||||
| 		err = DecodeResponse(resp.Body, response) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if outputResp { | ||||
| 		return resp, nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| type GeneralErrorResponse struct { | ||||
| 	Error    types.OpenAIError `json:"error"` | ||||
| 	Message  string            `json:"message"` | ||||
| 	Msg      string            `json:"msg"` | ||||
| 	Err      string            `json:"err"` | ||||
| 	ErrorMsg string            `json:"error_msg"` | ||||
| 	Header   struct { | ||||
| 		Message string `json:"message"` | ||||
| 	} `json:"header"` | ||||
| 	Response struct { | ||||
| 		Error struct { | ||||
| 			Message string `json:"message"` | ||||
| 		} `json:"error"` | ||||
| 	} `json:"response"` | ||||
| } | ||||
|  | ||||
| func (e GeneralErrorResponse) ToMessage() string { | ||||
| 	if e.Error.Message != "" { | ||||
| 		return e.Error.Message | ||||
| 	} | ||||
| 	if e.Message != "" { | ||||
| 		return e.Message | ||||
| 	} | ||||
| 	if e.Msg != "" { | ||||
| 		return e.Msg | ||||
| 	} | ||||
| 	if e.Err != "" { | ||||
| 		return e.Err | ||||
| 	} | ||||
| 	if e.ErrorMsg != "" { | ||||
| 		return e.ErrorMsg | ||||
| 	} | ||||
| 	if e.Header.Message != "" { | ||||
| 		return e.Header.Message | ||||
| 	} | ||||
| 	if e.Response.Error.Message != "" { | ||||
| 		return e.Response.Error.Message | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| // 处理错误响应 | ||||
| func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ | ||||
| 		StatusCode: resp.StatusCode, | ||||
| 		OpenAIError: types.OpenAIError{ | ||||
| 			Message: "", | ||||
| 			Type:    "upstream_error", | ||||
| 			Code:    "bad_response_status_code", | ||||
| 			Param:   strconv.Itoa(resp.StatusCode), | ||||
| 		}, | ||||
| 	} | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	// var errorResponse types.OpenAIErrorResponse | ||||
| 	var errorResponse GeneralErrorResponse | ||||
| 	err = json.Unmarshal(responseBody, &errorResponse) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if errorResponse.Error.Message != "" { | ||||
| 		// OpenAI format error, so we override the default one | ||||
| 		openAIErrorWithStatusCode.OpenAIError = errorResponse.Error | ||||
| 	} else { | ||||
| 		openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage() | ||||
| 	} | ||||
| 	if openAIErrorWithStatusCode.OpenAIError.Message == "" { | ||||
| 		openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (c *Client) SendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { | ||||
| 	resp, err := HttpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return resp.Body, nil | ||||
| } | ||||
|  | ||||
| func IsFailureStatusCode(resp *http.Response) bool { | ||||
| 	return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest | ||||
| } | ||||
|  | ||||
| func DecodeResponse(body io.Reader, v any) error { | ||||
| 	if v == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if result, ok := v.(*string); ok { | ||||
| 		return DecodeString(body, result) | ||||
| 	} | ||||
|  | ||||
| 	if stringer, ok := v.(Stringer); ok { | ||||
| 		return DecodeString(body, stringer.GetString()) | ||||
| 	} | ||||
|  | ||||
| 	return json.NewDecoder(body).Decode(v) | ||||
| } | ||||
|  | ||||
| func DecodeString(body io.Reader, output *string) error { | ||||
| 	b, err := io.ReadAll(body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	*output = string(b) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func SetEventStreamHeaders(c *gin.Context) { | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| } | ||||
| @@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | ||||
| var DisplayInCurrencyEnabled = true | ||||
| var DisplayTokenStatEnabled = true | ||||
|  | ||||
| var UsingSQLite = false | ||||
|  | ||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||
|  | ||||
| var SessionSecret = uuid.New().String() | ||||
| var SQLitePath = "one-api.db" | ||||
|  | ||||
| var OptionMap map[string]string | ||||
| var OptionMapRWMutex sync.RWMutex | ||||
| @@ -81,6 +78,7 @@ var QuotaForInviter = 0 | ||||
| var QuotaForInvitee = 0 | ||||
| var ChannelDisableThreshold = 5.0 | ||||
| var AutomaticDisableChannelEnabled = false | ||||
| var AutomaticEnableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| @@ -98,6 +96,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||
| var BatchUpdateEnabled = false | ||||
| var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||
|  | ||||
| var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second | ||||
|  | ||||
| const ( | ||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | ||||
| ) | ||||
| @@ -122,7 +122,7 @@ var ( | ||||
| 	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) | ||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) | ||||
| 	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 100) | ||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	UploadRateLimitNum            = 10 | ||||
| @@ -156,9 +156,10 @@ const ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelStatusUnknown  = 0 | ||||
| 	ChannelStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||
| 	ChannelStatusDisabled = 2 // also don't use 0 | ||||
| 	ChannelStatusUnknown          = 0 | ||||
| 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! | ||||
| 	ChannelStatusManuallyDisabled = 2 // also don't use 0 | ||||
| 	ChannelStatusAutoDisabled     = 3 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -185,30 +186,51 @@ const ( | ||||
| 	ChannelTypeOpenRouter     = 20 | ||||
| 	ChannelTypeAIProxyLibrary = 21 | ||||
| 	ChannelTypeFastGPT        = 22 | ||||
| 	ChannelTypeTencent        = 23 | ||||
| 	ChannelTypeAzureSpeech    = 24 | ||||
| 	ChannelTypeGemini         = 25 | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| 	"",                                // 0 | ||||
| 	"https://api.openai.com",          // 1 | ||||
| 	"https://oa.api2d.net",            // 2 | ||||
| 	"",                                // 3 | ||||
| 	"https://api.closeai-proxy.xyz",   // 4 | ||||
| 	"https://api.openai-sb.com",       // 5 | ||||
| 	"https://api.openaimax.com",       // 6 | ||||
| 	"https://api.ohmygpt.com",         // 7 | ||||
| 	"",                                // 8 | ||||
| 	"https://api.caipacity.com",       // 9 | ||||
| 	"https://api.aiproxy.io",          // 10 | ||||
| 	"",                                // 11 | ||||
| 	"https://api.api2gpt.com",         // 12 | ||||
| 	"https://api.aigc2d.com",          // 13 | ||||
| 	"https://api.anthropic.com",       // 14 | ||||
| 	"https://aip.baidubce.com",        // 15 | ||||
| 	"https://open.bigmodel.cn",        // 16 | ||||
| 	"https://dashscope.aliyuncs.com",  // 17 | ||||
| 	"",                                // 18 | ||||
| 	"https://ai.360.cn",               // 19 | ||||
| 	"https://openrouter.ai/api",       // 20 | ||||
| 	"https://api.aiproxy.io",          // 21 | ||||
| 	"https://fastgpt.run/api/openapi", // 22 | ||||
| 	"",                                  // 0 | ||||
| 	"https://api.openai.com",            // 1 | ||||
| 	"https://oa.api2d.net",              // 2 | ||||
| 	"",                                  // 3 | ||||
| 	"https://api.closeai-proxy.xyz",     // 4 | ||||
| 	"https://api.openai-sb.com",         // 5 | ||||
| 	"https://api.openaimax.com",         // 6 | ||||
| 	"https://api.ohmygpt.com",           // 7 | ||||
| 	"",                                  // 8 | ||||
| 	"https://api.caipacity.com",         // 9 | ||||
| 	"https://api.aiproxy.io",            // 10 | ||||
| 	"",                                  // 11 | ||||
| 	"https://api.api2gpt.com",           // 12 | ||||
| 	"https://api.aigc2d.com",            // 13 | ||||
| 	"https://api.anthropic.com",         // 14 | ||||
| 	"https://aip.baidubce.com",          // 15 | ||||
| 	"https://open.bigmodel.cn",          // 16 | ||||
| 	"https://dashscope.aliyuncs.com",    // 17 | ||||
| 	"",                                  // 18 | ||||
| 	"https://ai.360.cn",                 // 19 | ||||
| 	"https://openrouter.ai/api",         // 20 | ||||
| 	"https://api.aiproxy.io",            // 21 | ||||
| 	"https://fastgpt.run/api/openapi",   // 22 | ||||
| 	"https://hunyuan.cloud.tencent.com", //23 | ||||
| 	"",                                  //24 | ||||
| 	"",                                  //25 | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	RelayModeUnknown = iota | ||||
| 	RelayModeChatCompletions | ||||
| 	RelayModeCompletions | ||||
| 	RelayModeEmbeddings | ||||
| 	RelayModeModerations | ||||
| 	RelayModeImagesGenerations | ||||
| 	RelayModeImagesEdits | ||||
| 	RelayModeImagesVariations | ||||
| 	RelayModeEdits | ||||
| 	RelayModeAudioSpeech | ||||
| 	RelayModeAudioTranscription | ||||
| 	RelayModeAudioTranslation | ||||
| ) | ||||
|   | ||||
							
								
								
									
										7
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package common | ||||
|  | ||||
| var UsingSQLite = false | ||||
| var UsingPostgreSQL = false | ||||
|  | ||||
| var SQLitePath = "one-api.db" | ||||
| var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) | ||||
| @@ -1,11 +1,13 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"net/smtp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| @@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 		SMTPFrom = SMTPAccount | ||||
| 	} | ||||
| 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) | ||||
|  | ||||
| 	// Extract domain from SMTPFrom | ||||
| 	parts := strings.Split(SMTPFrom, "@") | ||||
| 	var domain string | ||||
| 	if len(parts) > 1 { | ||||
| 		domain = parts[1] | ||||
| 	} | ||||
| 	// Generate a unique Message-ID | ||||
| 	buf := make([]byte, 16) | ||||
| 	_, err := rand.Read(buf) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	messageId := fmt.Sprintf("<%x@%s>", buf, domain) | ||||
|  | ||||
| 	mail := []byte(fmt.Sprintf("To: %s\r\n"+ | ||||
| 		"From: %s<%s>\r\n"+ | ||||
| 		"Subject: %s\r\n"+ | ||||
| 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 | ||||
| 		"Date: %s\r\n"+ | ||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, content)) | ||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||
| 	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) | ||||
| 	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) | ||||
| 	to := strings.Split(receiver, ";") | ||||
| 	var err error | ||||
|  | ||||
| 	if SMTPPort == 465 { | ||||
| 		tlsConfig := &tls.Config{ | ||||
| 			InsecureSkipVerify: true, | ||||
|   | ||||
							
								
								
									
										71
									
								
								common/form_builder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								common/form_builder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"mime/multipart" | ||||
| 	"path" | ||||
| ) | ||||
|  | ||||
| type FormBuilder interface { | ||||
| 	CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error | ||||
| 	CreateFormFileReader(fieldname string, r io.Reader, filename string) error | ||||
| 	WriteField(fieldname, value string) error | ||||
| 	Close() error | ||||
| 	FormDataContentType() string | ||||
| } | ||||
|  | ||||
| type DefaultFormBuilder struct { | ||||
| 	writer *multipart.Writer | ||||
| } | ||||
|  | ||||
| func NewFormBuilder(body io.Writer) *DefaultFormBuilder { | ||||
| 	return &DefaultFormBuilder{ | ||||
| 		writer: multipart.NewWriter(body), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error { | ||||
| 	file, err := fileHeader.Open() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	defer file.Close() | ||||
|  | ||||
| 	return fb.createFormFile(fieldname, file, fileHeader.Filename) | ||||
| } | ||||
|  | ||||
| func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { | ||||
| 	return fb.createFormFile(fieldname, r, path.Base(filename)) | ||||
| } | ||||
|  | ||||
| func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { | ||||
| 	if filename == "" { | ||||
| 		return fmt.Errorf("filename cannot be empty") | ||||
| 	} | ||||
|  | ||||
| 	fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	_, err = io.Copy(fieldWriter, r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { | ||||
| 	return fb.writer.WriteField(fieldname, value) | ||||
| } | ||||
|  | ||||
| func (fb *DefaultFormBuilder) Close() error { | ||||
| 	return fb.writer.Close() | ||||
| } | ||||
|  | ||||
| func (fb *DefaultFormBuilder) FormDataContentType() string { | ||||
| 	return fb.writer.FormDataContentType() | ||||
| } | ||||
| @@ -2,9 +2,12 @@ package common | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-playground/validator/v10" | ||||
| ) | ||||
|  | ||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| @@ -16,11 +19,43 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	err = json.Unmarshal(requestBody, &v) | ||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	err = c.ShouldBind(v) | ||||
| 	if err != nil { | ||||
| 		if errs, ok := err.(validator.ValidationErrors); ok { | ||||
| 			// 返回第一个错误字段的名称 | ||||
| 			return fmt.Errorf("field %s is required", errs[0].Field()) | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| 	// Reset request body | ||||
|  | ||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWithStatusCode { | ||||
| 	return StringErrorWrapper(err.Error(), code, statusCode) | ||||
| } | ||||
|  | ||||
| func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode { | ||||
| 	openAIError := types.OpenAIError{ | ||||
| 		Message: err, | ||||
| 		Type:    "one_api_error", | ||||
| 		Code:    code, | ||||
| 	} | ||||
| 	return &types.OpenAIErrorWithStatusCode{ | ||||
| 		OpenAIError: openAIError, | ||||
| 		StatusCode:  statusCode, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func AbortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| 	c.JSON(statusCode, gin.H{ | ||||
| 		"error": gin.H{ | ||||
| 			"message": message, | ||||
| 			"type":    "one_api_error", | ||||
| 		}, | ||||
| 	}) | ||||
| 	c.Abort() | ||||
| 	LogError(c.Request.Context(), message) | ||||
| } | ||||
|   | ||||
							
								
								
									
										64
									
								
								common/image/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								common/image/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| package image | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/base64" | ||||
| 	"image" | ||||
| 	_ "image/gif" | ||||
| 	_ "image/jpeg" | ||||
| 	_ "image/png" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
|  | ||||
| 	_ "golang.org/x/image/webp" | ||||
| ) | ||||
|  | ||||
| func GetImageSizeFromUrl(url string) (width int, height int, err error) { | ||||
| 	resp, err := http.Get(url) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	img, _, err := image.DecodeConfig(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	return img.Width, img.Height, nil | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	reg = regexp.MustCompile(`data:image/([^;]+);base64,`) | ||||
| ) | ||||
|  | ||||
| var readerPool = sync.Pool{ | ||||
| 	New: func() interface{} { | ||||
| 		return &bytes.Reader{} | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { | ||||
| 	decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, "")) | ||||
| 	if err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	reader := readerPool.Get().(*bytes.Reader) | ||||
| 	defer readerPool.Put(reader) | ||||
| 	reader.Reset(decoded) | ||||
|  | ||||
| 	img, _, err := image.DecodeConfig(reader) | ||||
| 	if err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	return img.Width, img.Height, nil | ||||
| } | ||||
|  | ||||
| func GetImageSize(image string) (width int, height int, err error) { | ||||
| 	if strings.HasPrefix(image, "data:image/") { | ||||
| 		return GetImageSizeFromBase64(image) | ||||
| 	} | ||||
| 	return GetImageSizeFromUrl(image) | ||||
| } | ||||
							
								
								
									
										154
									
								
								common/image/image_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								common/image/image_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,154 @@ | ||||
| package image_test | ||||
|  | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"image" | ||||
| 	_ "image/gif" | ||||
| 	_ "image/jpeg" | ||||
| 	_ "image/png" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	img "one-api/common/image" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	_ "golang.org/x/image/webp" | ||||
| ) | ||||
|  | ||||
| type CountingReader struct { | ||||
| 	reader    io.Reader | ||||
| 	BytesRead int | ||||
| } | ||||
|  | ||||
| func (r *CountingReader) Read(p []byte) (n int, err error) { | ||||
| 	n, err = r.reader.Read(p) | ||||
| 	r.BytesRead += n | ||||
| 	return n, err | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	cases = []struct { | ||||
| 		url    string | ||||
| 		format string | ||||
| 		width  int | ||||
| 		height int | ||||
| 	}{ | ||||
| 		{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, | ||||
| 		{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, | ||||
| 		{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, | ||||
| 		{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, | ||||
| 		{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| func TestDecode(t *testing.T) { | ||||
| 	// Bytes read: varies sometimes | ||||
| 	// jpeg: 1063892 | ||||
| 	// png: 294462 | ||||
| 	// webp: 99529 | ||||
| 	// gif: 956153 | ||||
| 	// jpeg#01: 32805 | ||||
| 	for _, c := range cases { | ||||
| 		t.Run("Decode:"+c.format, func(t *testing.T) { | ||||
| 			resp, err := http.Get(c.url) | ||||
| 			assert.NoError(t, err) | ||||
| 			defer resp.Body.Close() | ||||
| 			reader := &CountingReader{reader: resp.Body} | ||||
| 			img, format, err := image.Decode(reader) | ||||
| 			assert.NoError(t, err) | ||||
| 			size := img.Bounds().Size() | ||||
| 			assert.Equal(t, c.format, format) | ||||
| 			assert.Equal(t, c.width, size.X) | ||||
| 			assert.Equal(t, c.height, size.Y) | ||||
| 			t.Logf("Bytes read: %d", reader.BytesRead) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	// Bytes read: | ||||
| 	// jpeg: 4096 | ||||
| 	// png: 4096 | ||||
| 	// webp: 4096 | ||||
| 	// gif: 4096 | ||||
| 	// jpeg#01: 4096 | ||||
| 	for _, c := range cases { | ||||
| 		t.Run("DecodeConfig:"+c.format, func(t *testing.T) { | ||||
| 			resp, err := http.Get(c.url) | ||||
| 			assert.NoError(t, err) | ||||
| 			defer resp.Body.Close() | ||||
| 			reader := &CountingReader{reader: resp.Body} | ||||
| 			config, format, err := image.DecodeConfig(reader) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, c.format, format) | ||||
| 			assert.Equal(t, c.width, config.Width) | ||||
| 			assert.Equal(t, c.height, config.Height) | ||||
| 			t.Logf("Bytes read: %d", reader.BytesRead) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestBase64(t *testing.T) { | ||||
| 	// Bytes read: | ||||
| 	// jpeg: 1063892 | ||||
| 	// png: 294462 | ||||
| 	// webp: 99072 | ||||
| 	// gif: 953856 | ||||
| 	// jpeg#01: 32805 | ||||
| 	for _, c := range cases { | ||||
| 		t.Run("Decode:"+c.format, func(t *testing.T) { | ||||
| 			resp, err := http.Get(c.url) | ||||
| 			assert.NoError(t, err) | ||||
| 			defer resp.Body.Close() | ||||
| 			data, err := io.ReadAll(resp.Body) | ||||
| 			assert.NoError(t, err) | ||||
| 			encoded := base64.StdEncoding.EncodeToString(data) | ||||
| 			body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) | ||||
| 			reader := &CountingReader{reader: body} | ||||
| 			img, format, err := image.Decode(reader) | ||||
| 			assert.NoError(t, err) | ||||
| 			size := img.Bounds().Size() | ||||
| 			assert.Equal(t, c.format, format) | ||||
| 			assert.Equal(t, c.width, size.X) | ||||
| 			assert.Equal(t, c.height, size.Y) | ||||
| 			t.Logf("Bytes read: %d", reader.BytesRead) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	// Bytes read: | ||||
| 	// jpeg: 1536 | ||||
| 	// png: 768 | ||||
| 	// webp: 768 | ||||
| 	// gif: 1536 | ||||
| 	// jpeg#01: 3840 | ||||
| 	for _, c := range cases { | ||||
| 		t.Run("DecodeConfig:"+c.format, func(t *testing.T) { | ||||
| 			resp, err := http.Get(c.url) | ||||
| 			assert.NoError(t, err) | ||||
| 			defer resp.Body.Close() | ||||
| 			data, err := io.ReadAll(resp.Body) | ||||
| 			assert.NoError(t, err) | ||||
| 			encoded := base64.StdEncoding.EncodeToString(data) | ||||
| 			body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) | ||||
| 			reader := &CountingReader{reader: body} | ||||
| 			config, format, err := image.DecodeConfig(reader) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, c.format, format) | ||||
| 			assert.Equal(t, c.width, config.Width) | ||||
| 			assert.Equal(t, c.height, config.Height) | ||||
| 			t.Logf("Bytes read: %d", reader.BytesRead) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestGetImageSize(t *testing.T) { | ||||
| 	for i, c := range cases { | ||||
| 		t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { | ||||
| 			width, height, err := img.GetImageSize(c.url) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, c.width, width) | ||||
| 			assert.Equal(t, c.height, height) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @@ -6,6 +6,8 @@ import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
|  | ||||
| 	"github.com/joho/godotenv" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| @@ -23,6 +25,11 @@ func printHelp() { | ||||
| } | ||||
|  | ||||
| func init() { | ||||
| 	// 加载.env文件 | ||||
| 	err := godotenv.Load() | ||||
| 	if err != nil { | ||||
| 		SysLog("failed to load .env file: " + err.Error()) | ||||
| 	} | ||||
| 	flag.Parse() | ||||
|  | ||||
| 	if *PrintVersion { | ||||
| @@ -36,7 +43,11 @@ func init() { | ||||
| 	} | ||||
|  | ||||
| 	if os.Getenv("SESSION_SECRET") != "" { | ||||
| 		SessionSecret = os.Getenv("SESSION_SECRET") | ||||
| 		if os.Getenv("SESSION_SECRET") == "random_string" { | ||||
| 			SysError("SESSION_SECRET is set to an example value, please change it to a random string.") | ||||
| 		} else { | ||||
| 			SessionSecret = os.Getenv("SESSION_SECRET") | ||||
| 		} | ||||
| 	} | ||||
| 	if os.Getenv("SQLITE_PATH") != "" { | ||||
| 		SQLitePath = os.Getenv("SQLITE_PATH") | ||||
|   | ||||
							
								
								
									
										15
									
								
								common/marshaller.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								common/marshaller.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| ) | ||||
|  | ||||
| type Marshaller interface { | ||||
| 	Marshal(value any) ([]byte, error) | ||||
| } | ||||
|  | ||||
| type JSONMarshaller struct{} | ||||
|  | ||||
| func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) { | ||||
| 	return json.Marshal(value) | ||||
| } | ||||
| @@ -3,8 +3,32 @@ package common | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var DalleSizeRatios = map[string]map[string]float64{ | ||||
| 	"dall-e-2": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1.125, | ||||
| 		"1024x1024": 1.25, | ||||
| 	}, | ||||
| 	"dall-e-3": { | ||||
| 		"1024x1024": 1, | ||||
| 		"1024x1792": 2, | ||||
| 		"1792x1024": 2, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var DalleGenerationImageAmounts = map[string][2]int{ | ||||
| 	"dall-e-2": {1, 10}, | ||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||
| } | ||||
|  | ||||
| var DalleImagePromptLengthLimitations = map[string]int{ | ||||
| 	"dall-e-2": 1000, | ||||
| 	"dall-e-3": 4000, | ||||
| } | ||||
|  | ||||
| // ModelRatio | ||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||
| @@ -19,12 +43,15 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-4-32k":                 30, | ||||
| 	"gpt-4-32k-0314":            30, | ||||
| 	"gpt-4-32k-0613":            30, | ||||
| 	"gpt-4-1106-preview":        5,    // $0.01 / 1K tokens | ||||
| 	"gpt-4-vision-preview":      5,    // $0.01 / 1K tokens | ||||
| 	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-0301":        0.75, | ||||
| 	"gpt-3.5-turbo-0613":        0.75, | ||||
| 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | ||||
| 	"gpt-3.5-turbo-16k-0613":    1.5, | ||||
| 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-1106":        0.5,  // $0.001 / 1K tokens | ||||
| 	"text-ada-001":              0.2, | ||||
| 	"text-babbage-001":          0.25, | ||||
| 	"text-curie-001":            1, | ||||
| @@ -32,7 +59,11 @@ var ModelRatio = map[string]float64{ | ||||
| 	"text-davinci-003":          10, | ||||
| 	"text-davinci-edit-001":     10, | ||||
| 	"code-davinci-edit-001":     10, | ||||
| 	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||
| 	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||
| 	"tts-1":                     7.5, // $0.015 / 1K characters | ||||
| 	"tts-1-1106":                7.5, | ||||
| 	"tts-1-hd":                  15, // $0.030 / 1K characters | ||||
| 	"tts-1-hd-1106":             15, | ||||
| 	"davinci":                   10, | ||||
| 	"curie":                     10, | ||||
| 	"babbage":                   10, | ||||
| @@ -41,25 +72,33 @@ var ModelRatio = map[string]float64{ | ||||
| 	"text-search-ada-doc-001":   10, | ||||
| 	"text-moderation-stable":    0.1, | ||||
| 	"text-moderation-latest":    0.1, | ||||
| 	"dall-e":                    8, | ||||
| 	"dall-e-2":                  8,      // $0.016 - $0.020 / image | ||||
| 	"dall-e-3":                  20,     // $0.040 - $0.120 / image | ||||
| 	"claude-instant-1":          0.815,  // $1.63 / 1M tokens | ||||
| 	"claude-2":                  5.51,   // $11.02 / 1M tokens | ||||
| 	"claude-2.0":                5.51,   // $11.02 / 1M tokens | ||||
| 	"claude-2.1":                5.51,   // $11.02 / 1M tokens | ||||
| 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens | ||||
| 	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens | ||||
| 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | ||||
| 	"PaLM-2":                    1, | ||||
| 	"gemini-pro":                1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens | ||||
| 	"qwen-turbo":                0.8572, // ¥0.012 / 1k tokens | ||||
| 	"qwen-plus":                 10,     // ¥0.14 / 1k tokens | ||||
| 	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens  // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing | ||||
| 	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens | ||||
| 	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens | ||||
| 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | ||||
| 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | ||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | ||||
| 	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens | ||||
| 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||
| } | ||||
|  | ||||
| func ModelRatio2JSONString() string { | ||||
| @@ -86,9 +125,24 @@ func GetModelRatio(name string) float64 { | ||||
|  | ||||
| func GetCompletionRatio(name string) float64 { | ||||
| 	if strings.HasPrefix(name, "gpt-3.5") { | ||||
| 		if strings.HasSuffix(name, "1106") { | ||||
| 			return 2 | ||||
| 		} | ||||
| 		if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { | ||||
| 			// TODO: clear this after 2023-12-11 | ||||
| 			now := time.Now() | ||||
| 			// https://platform.openai.com/docs/models/continuous-model-upgrades | ||||
| 			// if after 2023-12-11, use 2 | ||||
| 			if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { | ||||
| 				return 2 | ||||
| 			} | ||||
| 		} | ||||
| 		return 1.333333 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "gpt-4") { | ||||
| 		if strings.HasSuffix(name, "preview") { | ||||
| 			return 3 | ||||
| 		} | ||||
| 		return 2 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "claude-instant-1") { | ||||
|   | ||||
							
								
								
									
										59
									
								
								common/quota.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								common/quota.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| package common | ||||
|  | ||||
| // type Quota struct { | ||||
| // 	ModelName  string | ||||
| // 	ModelRatio float64 | ||||
| // 	GroupRatio float64 | ||||
| // 	Ratio      float64 | ||||
| // 	UserQuota  int | ||||
| // } | ||||
|  | ||||
| // func CreateQuota(modelName string, userQuota int, group string) *Quota { | ||||
| // 	modelRatio := GetModelRatio(modelName) | ||||
| // 	groupRatio := GetGroupRatio(group) | ||||
|  | ||||
| // 	return &Quota{ | ||||
| // 		ModelName:  modelName, | ||||
| // 		ModelRatio: modelRatio, | ||||
| // 		GroupRatio: groupRatio, | ||||
| // 		Ratio:      modelRatio * groupRatio, | ||||
| // 		UserQuota:  userQuota, | ||||
| // 	} | ||||
| // } | ||||
|  | ||||
| // func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||
| // 	if ApproximateTokenEnabled { | ||||
| // 		return int(float64(len(text)) * 0.38) | ||||
| // 	} | ||||
| // 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||
| // } | ||||
|  | ||||
| // func (q *Quota) CountTokenMessages(messages []Message, model string) int { | ||||
| // 	tokenEncoder := q.getTokenEncoder(model) | ||||
| // 	// Reference: | ||||
| // 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
| // 	// https://github.com/pkoukk/tiktoken-go/issues/6 | ||||
| // 	// | ||||
| // 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n | ||||
| // 	var tokensPerMessage int | ||||
| // 	var tokensPerName int | ||||
| // 	if model == "gpt-3.5-turbo-0301" { | ||||
| // 		tokensPerMessage = 4 | ||||
| // 		tokensPerName = -1 // If there's a name, the role is omitted | ||||
| // 	} else { | ||||
| // 		tokensPerMessage = 3 | ||||
| // 		tokensPerName = 1 | ||||
| // 	} | ||||
| // 	tokenNum := 0 | ||||
| // 	for _, message := range messages { | ||||
| // 		tokenNum += tokensPerMessage | ||||
| // 		tokenNum += q.getTokenNum(tokenEncoder, message.StringContent()) | ||||
| // 		tokenNum += q.getTokenNum(tokenEncoder, message.Role) | ||||
| // 		if message.Name != nil { | ||||
| // 			tokenNum += tokensPerName | ||||
| // 			tokenNum += q.getTokenNum(tokenEncoder, *message.Name) | ||||
| // 		} | ||||
| // 	} | ||||
| // 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> | ||||
| // 	return tokenNum | ||||
| // } | ||||
							
								
								
									
										50
									
								
								common/request_builder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								common/request_builder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type RequestBuilder interface { | ||||
| 	Build(method, url string, body any, header http.Header) (*http.Request, error) | ||||
| } | ||||
|  | ||||
| type HTTPRequestBuilder struct { | ||||
| 	marshaller Marshaller | ||||
| } | ||||
|  | ||||
| func NewRequestBuilder() *HTTPRequestBuilder { | ||||
| 	return &HTTPRequestBuilder{ | ||||
| 		marshaller: &JSONMarshaller{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (b *HTTPRequestBuilder) Build( | ||||
| 	method string, | ||||
| 	url string, | ||||
| 	body any, | ||||
| 	header http.Header, | ||||
| ) (req *http.Request, err error) { | ||||
| 	var bodyReader io.Reader | ||||
| 	if body != nil { | ||||
| 		if v, ok := body.(io.Reader); ok { | ||||
| 			bodyReader = v | ||||
| 		} else { | ||||
| 			var reqBytes []byte | ||||
| 			reqBytes, err = b.marshaller.Marshal(body) | ||||
| 			if err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			bodyReader = bytes.NewBuffer(reqBytes) | ||||
| 		} | ||||
| 	} | ||||
| 	req, err = http.NewRequest(method, url, bodyReader) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if header != nil { | ||||
| 		req.Header = header | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										238
									
								
								common/token.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								common/token.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,238 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"strings" | ||||
|  | ||||
| 	"one-api/common/image" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| ) | ||||
|  | ||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||
| var defaultTokenEncoder *tiktoken.Tiktoken | ||||
|  | ||||
| func InitTokenEncoders() { | ||||
| 	SysLog("initializing token encoders") | ||||
| 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||
| 	if err != nil { | ||||
| 		FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	defaultTokenEncoder = gpt35TokenEncoder | ||||
| 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||
| 	if err != nil { | ||||
| 		FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	for model, _ := range ModelRatio { | ||||
| 		if strings.HasPrefix(model, "gpt-3.5") { | ||||
| 			tokenEncoderMap[model] = gpt35TokenEncoder | ||||
| 		} else if strings.HasPrefix(model, "gpt-4") { | ||||
| 			tokenEncoderMap[model] = gpt4TokenEncoder | ||||
| 		} else { | ||||
| 			tokenEncoderMap[model] = nil | ||||
| 		} | ||||
| 	} | ||||
| 	SysLog("token encoders initialized") | ||||
| } | ||||
|  | ||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||
| 	tokenEncoder, ok := tokenEncoderMap[model] | ||||
| 	if ok && tokenEncoder != nil { | ||||
| 		return tokenEncoder | ||||
| 	} | ||||
| 	if ok { | ||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||
| 		if err != nil { | ||||
| 			SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||
| 			tokenEncoder = defaultTokenEncoder | ||||
| 		} | ||||
| 		tokenEncoderMap[model] = tokenEncoder | ||||
| 		return tokenEncoder | ||||
| 	} | ||||
| 	return defaultTokenEncoder | ||||
| } | ||||
|  | ||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||
| 	if ApproximateTokenEnabled { | ||||
| 		return int(float64(len(text)) * 0.38) | ||||
| 	} | ||||
| 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||
| } | ||||
|  | ||||
| func CountTokenMessages(messages []types.ChatCompletionMessage, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	// Reference: | ||||
| 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
| 	// https://github.com/pkoukk/tiktoken-go/issues/6 | ||||
| 	// | ||||
| 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n | ||||
| 	var tokensPerMessage int | ||||
| 	var tokensPerName int | ||||
| 	if model == "gpt-3.5-turbo-0301" { | ||||
| 		tokensPerMessage = 4 | ||||
| 		tokensPerName = -1 // If there's a name, the role is omitted | ||||
| 	} else { | ||||
| 		tokensPerMessage = 3 | ||||
| 		tokensPerName = 1 | ||||
| 	} | ||||
| 	tokenNum := 0 | ||||
| 	for _, message := range messages { | ||||
| 		tokenNum += tokensPerMessage | ||||
| 		switch v := message.Content.(type) { | ||||
| 		case string: | ||||
| 			tokenNum += getTokenNum(tokenEncoder, v) | ||||
| 		case []any: | ||||
| 			for _, it := range v { | ||||
| 				m := it.(map[string]any) | ||||
| 				switch m["type"] { | ||||
| 				case "text": | ||||
| 					tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) | ||||
| 				case "image_url": | ||||
| 					imageUrl, ok := m["image_url"].(map[string]any) | ||||
| 					if ok { | ||||
| 						url := imageUrl["url"].(string) | ||||
| 						detail := "" | ||||
| 						if imageUrl["detail"] != nil { | ||||
| 							detail = imageUrl["detail"].(string) | ||||
| 						} | ||||
| 						imageTokens, err := countImageTokens(url, detail) | ||||
| 						if err != nil { | ||||
| 							SysError("error counting image tokens: " + err.Error()) | ||||
| 						} else { | ||||
| 							tokenNum += imageTokens | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.StringContent()) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Role) | ||||
| 		if message.Name != nil { | ||||
| 			tokenNum += tokensPerName | ||||
| 			tokenNum += getTokenNum(tokenEncoder, *message.Name) | ||||
| 		} | ||||
| 	} | ||||
| 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> | ||||
| 	return tokenNum | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	lowDetailCost         = 85 | ||||
| 	highDetailCostPerTile = 170 | ||||
| 	additionalCost        = 85 | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/guides/vision/calculating-costs | ||||
| // https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
| func countImageTokens(url string, detail string) (_ int, err error) { | ||||
| 	var fetchSize = true | ||||
| 	var width, height int | ||||
| 	// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding | ||||
| 	// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting. | ||||
| 	// According to the official guide, "low" disable the high-res model, | ||||
| 	// and only receive low-res 512px x 512px version of the image, indicating | ||||
| 	// that image is treated as low-res when size is smaller than 512px x 512px, | ||||
| 	// then we can assume that image size larger than 512px x 512px is treated | ||||
| 	// as high-res. Then we have the following logic: | ||||
| 	// if detail == "" || detail == "auto" { | ||||
| 	// 	width, height, err = image.GetImageSize(url) | ||||
| 	// 	if err != nil { | ||||
| 	// 		return 0, err | ||||
| 	// 	} | ||||
| 	// 	fetchSize = false | ||||
| 	// 	// not sure if this is correct | ||||
| 	// 	if width > 512 || height > 512 { | ||||
| 	// 		detail = "high" | ||||
| 	// 	} else { | ||||
| 	// 		detail = "low" | ||||
| 	// 	} | ||||
| 	// } | ||||
|  | ||||
| 	// However, in my test, it seems to be always the same as "high". | ||||
| 	// The following image, which is 125x50, is still treated as high-res, taken | ||||
| 	// 255 tokens in the response of non-stream chat completion api. | ||||
| 	// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg | ||||
| 	if detail == "" || detail == "auto" { | ||||
| 		// assume by test, not sure if this is correct | ||||
| 		detail = "high" | ||||
| 	} | ||||
| 	switch detail { | ||||
| 	case "low": | ||||
| 		return lowDetailCost, nil | ||||
| 	case "high": | ||||
| 		if fetchSize { | ||||
| 			width, height, err = image.GetImageSize(url) | ||||
| 			if err != nil { | ||||
| 				return 0, err | ||||
| 			} | ||||
| 		} | ||||
| 		if width > 2048 || height > 2048 { // max(width, height) > 2048 | ||||
| 			ratio := float64(2048) / math.Max(float64(width), float64(height)) | ||||
| 			width = int(float64(width) * ratio) | ||||
| 			height = int(float64(height) * ratio) | ||||
| 		} | ||||
| 		if width > 768 && height > 768 { // min(width, height) > 768 | ||||
| 			ratio := float64(768) / math.Min(float64(width), float64(height)) | ||||
| 			width = int(float64(width) * ratio) | ||||
| 			height = int(float64(height) * ratio) | ||||
| 		} | ||||
| 		numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) | ||||
| 		result := numSquares*highDetailCostPerTile + additionalCost | ||||
| 		return result, nil | ||||
| 	default: | ||||
| 		return 0, errors.New("invalid detail option") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func CountTokenInput(input any, model string) int { | ||||
| 	switch v := input.(type) { | ||||
| 	case string: | ||||
| 		return CountTokenInput(v, model) | ||||
| 	case []string: | ||||
| 		text := "" | ||||
| 		for _, s := range v { | ||||
| 			text += s | ||||
| 		} | ||||
| 		return CountTokenInput(text, model) | ||||
| 	} | ||||
| 	return 0 | ||||
| } | ||||
|  | ||||
| func CountTokenText(text string, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	return getTokenNum(tokenEncoder, text) | ||||
| } | ||||
|  | ||||
| func CountTokenImage(input interface{}) (int, error) { | ||||
| 	switch v := input.(type) { | ||||
| 	case types.ImageRequest: | ||||
| 		// 处理 ImageRequest | ||||
| 		return calculateToken(v.Model, v.Size, v.N, v.Quality) | ||||
| 	case types.ImageEditRequest: | ||||
| 		// 处理 ImageEditsRequest | ||||
| 		return calculateToken(v.Model, v.Size, v.N, "") | ||||
| 	default: | ||||
| 		return 0, errors.New("unsupported type") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func calculateToken(model string, size string, n int, quality string) (int, error) { | ||||
| 	imageCostRatio, hasValidSize := DalleSizeRatios[model][size] | ||||
|  | ||||
| 	if hasValidSize { | ||||
| 		if quality == "hd" && model == "dall-e-3" { | ||||
| 			if size == "1024x1024" { | ||||
| 				imageCostRatio *= 2 | ||||
| 			} else { | ||||
| 				imageCostRatio *= 1.5 | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		return 0, errors.New("size not supported for this image model") | ||||
| 	} | ||||
|  | ||||
| 	return int(imageCostRatio*1000) * n, nil | ||||
| } | ||||
| @@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int { | ||||
| func MessageWithRequestId(message string, id string) string { | ||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||
| } | ||||
|  | ||||
| func String2Int(str string) int { | ||||
| 	num, err := strconv.Atoi(str) | ||||
| 	if err != nil { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|   | ||||
| @@ -1,9 +1,11 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func GetSubscription(c *gin.Context) { | ||||
| @@ -21,13 +23,23 @@ func GetSubscription(c *gin.Context) { | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		remainQuota, err = model.GetUserQuota(userId) | ||||
| 		if err != nil { | ||||
| 			openAIError := types.OpenAIError{ | ||||
| 				Message: err.Error(), | ||||
| 				Type:    "upstream_error", | ||||
| 			} | ||||
| 			c.JSON(200, gin.H{ | ||||
| 				"error": openAIError, | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 		usedQuota, err = model.GetUserUsedQuota(userId) | ||||
| 	} | ||||
| 	if expiredTime <= 0 { | ||||
| 		expiredTime = 0 | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		openAIError := OpenAIError{ | ||||
| 		openAIError := types.OpenAIError{ | ||||
| 			Message: err.Error(), | ||||
| 			Type:    "upstream_error", | ||||
| 		} | ||||
| @@ -53,7 +65,6 @@ func GetSubscription(c *gin.Context) { | ||||
| 		AccessUntil:        expiredTime, | ||||
| 	} | ||||
| 	c.JSON(200, subscription) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetUsage(c *gin.Context) { | ||||
| @@ -69,7 +80,7 @@ func GetUsage(c *gin.Context) { | ||||
| 		quota, err = model.GetUserUsedQuota(userId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		openAIError := OpenAIError{ | ||||
| 		openAIError := types.OpenAIError{ | ||||
| 			Message: err.Error(), | ||||
| 			Type:    "one_api_error", | ||||
| 		} | ||||
| @@ -87,5 +98,4 @@ func GetUsage(c *gin.Context) { | ||||
| 		TotalUsage: amount * 100, | ||||
| 	} | ||||
| 	c.JSON(200, usage) | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -1,13 +1,13 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"one-api/providers" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| @@ -46,216 +46,30 @@ type OpenAIUsageResponse struct { | ||||
| 	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar | ||||
| } | ||||
|  | ||||
| type OpenAISBUsageResponse struct { | ||||
| 	Msg  string `json:"msg"` | ||||
| 	Data *struct { | ||||
| 		Credit string `json:"credit"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type AIProxyUserOverviewResponse struct { | ||||
| 	Success   bool   `json:"success"` | ||||
| 	Message   string `json:"message"` | ||||
| 	ErrorCode int    `json:"error_code"` | ||||
| 	Data      struct { | ||||
| 		TotalPoints float64 `json:"totalPoints"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type API2GPTUsageResponse struct { | ||||
| 	Object         string  `json:"object"` | ||||
| 	TotalGranted   float64 `json:"total_granted"` | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| 	TotalRemaining float64 `json:"total_remaining"` | ||||
| } | ||||
|  | ||||
| type APGC2DGPTUsageResponse struct { | ||||
| 	//Grants         interface{} `json:"grants"` | ||||
| 	Object         string  `json:"object"` | ||||
| 	TotalAvailable float64 `json:"total_available"` | ||||
| 	TotalGranted   float64 `json:"total_granted"` | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| } | ||||
|  | ||||
| // GetAuthHeader get auth header | ||||
| func GetAuthHeader(token string) http.Header { | ||||
| 	h := http.Header{} | ||||
| 	h.Add("Authorization", fmt.Sprintf("Bearer %s", token)) | ||||
| 	return h | ||||
| } | ||||
|  | ||||
| func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { | ||||
| 	req, err := http.NewRequest(method, url, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	for k := range headers { | ||||
| 		req.Header.Add(k, headers.Get(k)) | ||||
| 	} | ||||
| 	res, err := httpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if res.StatusCode != http.StatusOK { | ||||
| 		return nil, fmt.Errorf("status code: %d", res.StatusCode) | ||||
| 	} | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = res.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return body, nil | ||||
| } | ||||
|  | ||||
| func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := OpenAICreditGrants{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(response.TotalAvailable) | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
|  | ||||
| func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := OpenAISBUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if response.Data == nil { | ||||
| 		return 0, errors.New(response.Msg) | ||||
| 	} | ||||
| 	balance, err := strconv.ParseFloat(response.Data.Credit, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
| } | ||||
|  | ||||
| func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://aiproxy.io/api/report/getUserOverview" | ||||
| 	headers := http.Header{} | ||||
| 	headers.Add("Api-Key", channel.Key) | ||||
| 	body, err := GetResponseBody("GET", url, channel, headers) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := AIProxyUserOverviewResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if !response.Success { | ||||
| 		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) | ||||
| 	} | ||||
| 	channel.UpdateBalance(response.Data.TotalPoints) | ||||
| 	return response.Data.TotalPoints, nil | ||||
| } | ||||
|  | ||||
| func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.api2gpt.com/dashboard/billing/credit_grants" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := API2GPTUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(response.TotalRemaining) | ||||
| 	return response.TotalRemaining, nil | ||||
| } | ||||
|  | ||||
| func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.aigc2d.com/dashboard/billing/credit_grants" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := APGC2DGPTUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(response.TotalAvailable) | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.GetBaseURL() == "" { | ||||
| 		channel.BaseURL = &baseURL | ||||
| 	req, err := http.NewRequest("POST", "/balance", nil) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypeOpenAI: | ||||
| 		if channel.GetBaseURL() != "" { | ||||
| 			baseURL = channel.GetBaseURL() | ||||
| 		} | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	case common.ChannelTypeCustom: | ||||
| 		baseURL = channel.GetBaseURL() | ||||
| 	case common.ChannelTypeCloseAI: | ||||
| 		return updateChannelCloseAIBalance(channel) | ||||
| 	case common.ChannelTypeOpenAISB: | ||||
| 		return updateChannelOpenAISBBalance(channel) | ||||
| 	case common.ChannelTypeAIProxy: | ||||
| 		return updateChannelAIProxyBalance(channel) | ||||
| 	case common.ChannelTypeAPI2GPT: | ||||
| 		return updateChannelAPI2GPTBalance(channel) | ||||
| 	case common.ChannelTypeAIGC2D: | ||||
| 		return updateChannelAIGC2DBalance(channel) | ||||
| 	default: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	} | ||||
| 	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) | ||||
| 	w := httptest.NewRecorder() | ||||
| 	c, _ := gin.CreateTestContext(w) | ||||
| 	c.Request = req | ||||
|  | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	setChannelToContext(c, channel) | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
|  | ||||
| 	provider := providers.GetProvider(channel.Type, c) | ||||
| 	if provider == nil { | ||||
| 		return 0, errors.New("provider not found") | ||||
| 	} | ||||
| 	subscription := OpenAISubscriptionResponse{} | ||||
| 	err = json.Unmarshal(body, &subscription) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
|  | ||||
| 	balanceProvider, ok := provider.(providersBase.BalanceInterface) | ||||
| 	if !ok { | ||||
| 		return 0, errors.New("provider not implemented") | ||||
| 	} | ||||
| 	now := time.Now() | ||||
| 	startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) | ||||
| 	endDate := now.Format("2006-01-02") | ||||
| 	if !subscription.HasPaymentMethod { | ||||
| 		startDate = now.AddDate(0, 0, -100).Format("2006-01-02") | ||||
| 	} | ||||
| 	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) | ||||
| 	body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	usage := OpenAIUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &usage) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	balance := subscription.HardLimitUSD - usage.TotalUsage/100 | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
|  | ||||
| 	return balanceProvider.Balance(channel) | ||||
|  | ||||
| } | ||||
|  | ||||
| func UpdateChannelBalance(c *gin.Context) { | ||||
|   | ||||
| @@ -1,95 +1,97 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"one-api/providers" | ||||
| 	providers_base "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeAli: | ||||
| 		fallthrough | ||||
| 	case common.ChannelType360: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		request.Model = "gpt-35-turbo" | ||||
| 		defer func() { | ||||
| 			if err != nil { | ||||
| 				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") | ||||
| 			} | ||||
| 		}() | ||||
| 	default: | ||||
| 		request.Model = "gpt-3.5-turbo" | ||||
| 	} | ||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) | ||||
| 	} else { | ||||
| 		if channel.GetBaseURL() != "" { | ||||
| 			requestURL = channel.GetBaseURL() | ||||
| 		} | ||||
| 		requestURL += "/v1/chat/completions" | ||||
| 	} | ||||
|  | ||||
| 	jsonData, err := json.Marshal(request) | ||||
| func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) { | ||||
| 	// 创建一个 http.Request | ||||
| 	req, err := http.NewRequest("POST", "/v1/chat/completions", nil) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		req.Header.Set("api-key", channel.Key) | ||||
| 	} else { | ||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	resp, err := httpClient.Do(req) | ||||
|  | ||||
| 	w := httptest.NewRecorder() | ||||
| 	c, _ := gin.CreateTestContext(w) | ||||
| 	c.Request = req | ||||
|  | ||||
| 	setChannelToContext(c, channel) | ||||
| 	// 创建映射 | ||||
| 	channelTypeToModel := map[int]string{ | ||||
| 		common.ChannelTypePaLM:      "PaLM-2", | ||||
| 		common.ChannelTypeAnthropic: "claude-2", | ||||
| 		common.ChannelTypeBaidu:     "ERNIE-Bot", | ||||
| 		common.ChannelTypeZhipu:     "chatglm_lite", | ||||
| 		common.ChannelTypeAli:       "qwen-turbo", | ||||
| 		common.ChannelType360:       "360GPT_S2_V9", | ||||
| 		common.ChannelTypeXunfei:    "SparkDesk", | ||||
| 		common.ChannelTypeTencent:   "hunyuan", | ||||
| 		common.ChannelTypeAzure:     "gpt-3.5-turbo", | ||||
| 	} | ||||
|  | ||||
| 	// 从映射中获取模型名称 | ||||
| 	model, ok := channelTypeToModel[channel.Type] | ||||
| 	if !ok { | ||||
| 		model = "gpt-3.5-turbo" // 默认值 | ||||
| 	} | ||||
| 	request.Model = model | ||||
|  | ||||
| 	provider := providers.GetProvider(channel.Type, c) | ||||
| 	if provider == nil { | ||||
| 		return errors.New("channel not implemented"), nil | ||||
| 	} | ||||
| 	chatProvider, ok := provider.(providers_base.ChatInterface) | ||||
| 	if !ok { | ||||
| 		return errors.New("channel not implemented"), nil | ||||
| 	} | ||||
|  | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	var response TextResponse | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&response) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	if modelMap != nil && modelMap[request.Model] != "" { | ||||
| 		request.Model = modelMap[request.Model] | ||||
| 	} | ||||
| 	if response.Usage.CompletionTokens == 0 { | ||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||
|  | ||||
| 	promptTokens := common.CountTokenMessages(request.Messages, request.Model) | ||||
| 	Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens) | ||||
| 	if openAIErrorWithStatusCode != nil { | ||||
| 		return nil, &openAIErrorWithStatusCode.OpenAIError | ||||
| 	} | ||||
|  | ||||
| 	if Usage.CompletionTokens == 0 { | ||||
| 		return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func buildTestRequest() *ChatRequest { | ||||
| 	testRequest := &ChatRequest{ | ||||
| 		Model:     "", // this will be set later | ||||
| func buildTestRequest() *types.ChatCompletionRequest { | ||||
| 	testRequest := &types.ChatCompletionRequest{ | ||||
| 		Messages: []types.ChatCompletionMessage{ | ||||
| 			{ | ||||
| 				Role:    "user", | ||||
| 				Content: "You just need to output 'hi' next.", | ||||
| 			}, | ||||
| 		}, | ||||
| 		Model:     "", | ||||
| 		MaxTokens: 1, | ||||
| 		Stream:    false, | ||||
| 	} | ||||
| 	testMessage := Message{ | ||||
| 		Role:    "user", | ||||
| 		Content: "hi", | ||||
| 	} | ||||
| 	testRequest.Messages = append(testRequest.Messages, testMessage) | ||||
| 	return testRequest | ||||
| } | ||||
|  | ||||
| @@ -136,20 +138,32 @@ func TestChannel(c *gin.Context) { | ||||
| var testAllChannelsLock sync.Mutex | ||||
| var testAllChannelsRunning bool = false | ||||
|  | ||||
| // disable & notify | ||||
| func disableChannel(channelId int, channelName string, reason string) { | ||||
| func notifyRootUser(subject string, content string) { | ||||
| 	if common.RootUserEmail == "" { | ||||
| 		common.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	err := common.SendEmail(subject, common.RootUserEmail, content) | ||||
| 	if err != nil { | ||||
| 		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // disable & notify | ||||
| func disableChannel(channelId int, channelName string, reason string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| // enable & notify | ||||
| func enableChannel(channelId int, channelName string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| func testAllChannels(notify bool) error { | ||||
| 	if common.RootUserEmail == "" { | ||||
| 		common.RootUserEmail = model.GetRootUserEmail() | ||||
| @@ -172,20 +186,21 @@ func testAllChannels(notify bool) error { | ||||
| 	} | ||||
| 	go func() { | ||||
| 		for _, channel := range channels { | ||||
| 			if channel.Status != common.ChannelStatusEnabled { | ||||
| 				continue | ||||
| 			} | ||||
| 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled | ||||
| 			tik := time.Now() | ||||
| 			err, openaiErr := testChannel(channel, *testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if shouldDisableChannel(openaiErr, -1) { | ||||
| 			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { | ||||
| 				enableChannel(channel.Id, channel.Name) | ||||
| 			} | ||||
| 			channel.UpdateResponseTime(milliseconds) | ||||
| 			time.Sleep(common.RequestInterval) | ||||
| 		} | ||||
| @@ -215,7 +230,6 @@ func TestAllChannels(c *gin.Context) { | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func AutomaticallyTestChannels(frequency int) { | ||||
|   | ||||
| @@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func DeleteDisabledChannel(c *gin.Context) { | ||||
| 	rows, err := model.DeleteDisabledChannel() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    rows, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func UpdateChannel(c *gin.Context) { | ||||
| 	channel := model.Channel{} | ||||
| 	err := c.ShouldBindJSON(&channel) | ||||
|   | ||||
| @@ -1,9 +1,10 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func GetGroups(c *gin.Context) { | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| @@ -55,12 +56,21 @@ func init() { | ||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| 	openAIModels = []OpenAIModels{ | ||||
| 		{ | ||||
| 			Id:         "dall-e", | ||||
| 			Id:         "dall-e-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "dall-e", | ||||
| 			Root:       "dall-e-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "dall-e-3", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "dall-e-3", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -72,6 +82,42 @@ func init() { | ||||
| 			Root:       "whisper-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-hd", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-hd", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-hd-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-hd-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo", | ||||
| 			Object:     "model", | ||||
| @@ -117,6 +163,15 @@ func init() { | ||||
| 			Root:       "gpt-3.5-turbo-16k-0613", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-3.5-turbo-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo-instruct", | ||||
| 			Object:     "model", | ||||
| @@ -180,6 +235,24 @@ func init() { | ||||
| 			Root:       "gpt-4-32k-0613", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-4-1106-preview", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-4-1106-preview", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-4-vision-preview", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-4-vision-preview", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "text-embedding-ada-002", | ||||
| 			Object:     "model", | ||||
| @@ -274,7 +347,7 @@ func init() { | ||||
| 			Id:         "claude-instant-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-instant-1", | ||||
| 			Parent:     nil, | ||||
| @@ -283,11 +356,29 @@ func init() { | ||||
| 			Id:         "claude-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2.1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2.1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2.0", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2.0", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot", | ||||
| 			Object:     "model", | ||||
| @@ -306,6 +397,15 @@ func init() { | ||||
| 			Root:       "ERNIE-Bot-turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot-4", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "ERNIE-Bot-4", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "Embedding-V1", | ||||
| 			Object:     "model", | ||||
| @@ -324,6 +424,24 @@ func init() { | ||||
| 			Root:       "PaLM-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gemini-pro", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "google", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gemini-pro", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_turbo", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_pro", | ||||
| 			Object:     "model", | ||||
| @@ -369,6 +487,24 @@ func init() { | ||||
| 			Root:       "qwen-plus", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-max", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-max", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-max-longcontext", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-max-longcontext", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "text-embedding-v1", | ||||
| 			Object:     "model", | ||||
| @@ -424,12 +560,12 @@ func init() { | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "360GPT_S2_V9.4", | ||||
| 			Id:         "hunyuan", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "360", | ||||
| 			OwnedBy:    "tencent", | ||||
| 			Permission: permission, | ||||
| 			Root:       "360GPT_S2_V9.4", | ||||
| 			Root:       "hunyuan", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 	} | ||||
| @@ -451,7 +587,7 @@ func RetrieveModel(c *gin.Context) { | ||||
| 	if model, ok := openAIModelsMap[modelId]; ok { | ||||
| 		c.JSON(200, model) | ||||
| 	} else { | ||||
| 		openAIError := OpenAIError{ | ||||
| 		openAIError := types.OpenAIError{ | ||||
| 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | ||||
| 			Type:    "invalid_request_error", | ||||
| 			Param:   "model", | ||||
|   | ||||
| @@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) { | ||||
| 		if option.Value == "true" && common.GitHubClientId == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!", | ||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
| @@ -1,220 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 | ||||
|  | ||||
| type AIProxyLibraryRequest struct { | ||||
| 	Model     string `json:"model"` | ||||
| 	Query     string `json:"query"` | ||||
| 	LibraryId string `json:"libraryId"` | ||||
| 	Stream    bool   `json:"stream"` | ||||
| } | ||||
|  | ||||
| type AIProxyLibraryError struct { | ||||
| 	ErrCode int    `json:"errCode"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type AIProxyLibraryDocument struct { | ||||
| 	Title string `json:"title"` | ||||
| 	URL   string `json:"url"` | ||||
| } | ||||
|  | ||||
| type AIProxyLibraryResponse struct { | ||||
| 	Success   bool                     `json:"success"` | ||||
| 	Answer    string                   `json:"answer"` | ||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||
| 	AIProxyLibraryError | ||||
| } | ||||
|  | ||||
| type AIProxyLibraryStreamResponse struct { | ||||
| 	Content   string                   `json:"content"` | ||||
| 	Finish    bool                     `json:"finish"` | ||||
| 	Model     string                   `json:"model"` | ||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||
| } | ||||
|  | ||||
| func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { | ||||
| 	query := "" | ||||
| 	if len(request.Messages) != 0 { | ||||
| 		query = request.Messages[len(request.Messages)-1].Content | ||||
| 	} | ||||
| 	return &AIProxyLibraryRequest{ | ||||
| 		Model:  request.Model, | ||||
| 		Stream: request.Stream, | ||||
| 		Query:  query, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { | ||||
| 	if len(documents) == 0 { | ||||
| 		return "" | ||||
| 	} | ||||
| 	content := "\n\n参考文档:\n" | ||||
| 	for i, document := range documents { | ||||
| 		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) | ||||
| 	} | ||||
| 	return content | ||||
| } | ||||
|  | ||||
| func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { | ||||
| 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents) | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: content, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      common.GetUUID(), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||
| 	choice.FinishReason = &stopFinishReason | ||||
| 	return &ChatCompletionsStreamResponse{ | ||||
| 		Id:      common.GetUUID(), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = response.Content | ||||
| 	return &ChatCompletionsStreamResponse{ | ||||
| 		Id:      common.GetUUID(), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   response.Model, | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	var documents []AIProxyLibraryDocument | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var AIProxyLibraryResponse AIProxyLibraryStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if len(AIProxyLibraryResponse.Documents) != 0 { | ||||
| 				documents = AIProxyLibraryResponse.Documents | ||||
| 			} | ||||
| 			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			response := documentsAIProxyLibrary(documents) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var AIProxyLibraryResponse AIProxyLibraryResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if AIProxyLibraryResponse.ErrCode != 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: AIProxyLibraryResponse.Message, | ||||
| 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode), | ||||
| 				Code:    AIProxyLibraryResponse.ErrCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -1,329 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||
|  | ||||
| type AliMessage struct { | ||||
| 	User string `json:"user"` | ||||
| 	Bot  string `json:"bot"` | ||||
| } | ||||
|  | ||||
| type AliInput struct { | ||||
| 	Prompt  string       `json:"prompt"` | ||||
| 	History []AliMessage `json:"history"` | ||||
| } | ||||
|  | ||||
| type AliParameters struct { | ||||
| 	TopP         float64 `json:"top_p,omitempty"` | ||||
| 	TopK         int     `json:"top_k,omitempty"` | ||||
| 	Seed         uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliChatRequest struct { | ||||
| 	Model      string        `json:"model"` | ||||
| 	Input      AliInput      `json:"input"` | ||||
| 	Parameters AliParameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliEmbeddingRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Input struct { | ||||
| 		Texts []string `json:"texts"` | ||||
| 	} `json:"input"` | ||||
| 	Parameters *struct { | ||||
| 		TextType string `json:"text_type,omitempty"` | ||||
| 	} `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliEmbedding struct { | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	TextIndex int       `json:"text_index"` | ||||
| } | ||||
|  | ||||
| type AliEmbeddingResponse struct { | ||||
| 	Output struct { | ||||
| 		Embeddings []AliEmbedding `json:"embeddings"` | ||||
| 	} `json:"output"` | ||||
| 	Usage AliUsage `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| type AliError struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type AliUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type AliOutput struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type AliChatResponse struct { | ||||
| 	Output AliOutput `json:"output"` | ||||
| 	Usage  AliUsage  `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||
| 	prompt := "" | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			if i == len(request.Messages)-1 { | ||||
| 				prompt = message.Content | ||||
| 				break | ||||
| 			} | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  request.Messages[i+1].Content, | ||||
| 			}) | ||||
| 			i++ | ||||
| 		} | ||||
| 	} | ||||
| 	return &AliChatRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Input: AliInput{ | ||||
| 			Prompt:  prompt, | ||||
| 			History: messages, | ||||
| 		}, | ||||
| 		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's | ||||
| 		//	TopP: request.TopP, | ||||
| 		//	TopK: 50, | ||||
| 		//	//Seed:         0, | ||||
| 		//	//EnableSearch: false, | ||||
| 		//}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { | ||||
| 	return &AliEmbeddingRequest{ | ||||
| 		Model: "text-embedding-v1", | ||||
| 		Input: struct { | ||||
| 			Texts []string `json:"texts"` | ||||
| 		}{ | ||||
| 			Texts: request.ParseInput(), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var aliResponse AliEmbeddingResponse | ||||
| 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
|  | ||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Usage:  Usage{TotalTokens: response.Usage.TotalTokens}, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range response.Output.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||
| 			Object:    `embedding`, | ||||
| 			Index:     item.TextIndex, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage: Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	if aliResponse.Output.FinishReason != "null" { | ||||
| 		finishReason := aliResponse.Output.FinishReason | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	lastResponseText := "" | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse AliChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if aliResponse.Usage.OutputTokens != 0 { | ||||
| 				usage.PromptTokens = aliResponse.Usage.InputTokens | ||||
| 				usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			} | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var aliResponse AliChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -1,153 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	audioModel := "whisper-1" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	preConsumedTokens := common.PreConsumedQuota | ||||
| 	modelRatio := common.GetModelRatio(audioModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota > 100*preConsumedQuota { | ||||
| 		// in this case, we do not pre-consume quota | ||||
| 		// because the user has enough quota | ||||
| 		preConsumedQuota = 0 | ||||
| 	} | ||||
| 	if preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[audioModel] != "" { | ||||
| 			audioModel = modelMap[audioModel] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
|  | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	requestBody := c.Request.Body | ||||
|  | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
|  | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var audioResponse AudioResponse | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		go func() { | ||||
| 			quota := countTokenText(audioResponse.Text, audioModel) | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		}() | ||||
| 	}(c.Request.Context()) | ||||
|  | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &audioResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
|  | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @@ -1,359 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 | ||||
|  | ||||
| type BaiduTokenResponse struct { | ||||
| 	ExpiresIn   int    `json:"expires_in"` | ||||
| 	AccessToken string `json:"access_token"` | ||||
| } | ||||
|  | ||||
| type BaiduMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type BaiduChatRequest struct { | ||||
| 	Messages []BaiduMessage `json:"messages"` | ||||
| 	Stream   bool           `json:"stream"` | ||||
| 	UserId   string         `json:"user_id,omitempty"` | ||||
| } | ||||
|  | ||||
| type BaiduError struct { | ||||
| 	ErrorCode int    `json:"error_code"` | ||||
| 	ErrorMsg  string `json:"error_msg"` | ||||
| } | ||||
|  | ||||
| type BaiduChatResponse struct { | ||||
| 	Id               string `json:"id"` | ||||
| 	Object           string `json:"object"` | ||||
| 	Created          int64  `json:"created"` | ||||
| 	Result           string `json:"result"` | ||||
| 	IsTruncated      bool   `json:"is_truncated"` | ||||
| 	NeedClearHistory bool   `json:"need_clear_history"` | ||||
| 	Usage            Usage  `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| type BaiduChatStreamResponse struct { | ||||
| 	BaiduChatResponse | ||||
| 	SentenceId int  `json:"sentence_id"` | ||||
| 	IsEnd      bool `json:"is_end"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingRequest struct { | ||||
| 	Input []string `json:"input"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingData struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	Index     int       `json:"index"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingResponse struct { | ||||
| 	Id      string               `json:"id"` | ||||
| 	Object  string               `json:"object"` | ||||
| 	Created int64                `json:"created"` | ||||
| 	Data    []BaiduEmbeddingData `json:"data"` | ||||
| 	Usage   Usage                `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| type BaiduAccessToken struct { | ||||
| 	AccessToken      string    `json:"access_token"` | ||||
| 	Error            string    `json:"error,omitempty"` | ||||
| 	ErrorDescription string    `json:"error_description,omitempty"` | ||||
| 	ExpiresIn        int64     `json:"expires_in,omitempty"` | ||||
| 	ExpiresAt        time.Time `json:"-"` | ||||
| } | ||||
|  | ||||
| var baiduTokenStore sync.Map | ||||
|  | ||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &BaiduChatRequest{ | ||||
| 		Messages: messages, | ||||
| 		Stream:   request.Stream, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Result, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.Id, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: response.Created, | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage:   response.Usage, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = baiduResponse.Result | ||||
| 	if baiduResponse.IsEnd { | ||||
| 		choice.FinishReason = &stopFinishReason | ||||
| 	} | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      baiduResponse.Id, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: baiduResponse.Created, | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||
| 	return &BaiduEmbeddingRequest{ | ||||
| 		Input: request.ParseInput(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Model:  "baidu-embedding", | ||||
| 		Usage:  response.Usage, | ||||
| 	} | ||||
| 	for _, item := range response.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[6:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var baiduResponse BaiduChatStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if baiduResponse.Usage.TotalTokens != 0 { | ||||
| 				usage.TotalTokens = baiduResponse.Usage.TotalTokens | ||||
| 				usage.PromptTokens = baiduResponse.Usage.PromptTokens | ||||
| 				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens | ||||
| 			} | ||||
| 			response := streamResponseBaidu2OpenAI(&baiduResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduEmbeddingResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func getBaiduAccessToken(apiKey string) (string, error) { | ||||
| 	if val, ok := baiduTokenStore.Load(apiKey); ok { | ||||
| 		var accessToken BaiduAccessToken | ||||
| 		if accessToken, ok = val.(BaiduAccessToken); ok { | ||||
| 			// soon this will expire | ||||
| 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { | ||||
| 				go func() { | ||||
| 					_, _ = getBaiduAccessTokenHelper(apiKey) | ||||
| 				}() | ||||
| 			} | ||||
| 			return accessToken.AccessToken, nil | ||||
| 		} | ||||
| 	} | ||||
| 	accessToken, err := getBaiduAccessTokenHelper(apiKey) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if accessToken == nil { | ||||
| 		return "", errors.New("getBaiduAccessToken return a nil token") | ||||
| 	} | ||||
| 	return (*accessToken).AccessToken, nil | ||||
| } | ||||
|  | ||||
| func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { | ||||
| 	parts := strings.Split(apiKey, "|") | ||||
| 	if len(parts) != 2 { | ||||
| 		return nil, errors.New("invalid baidu apikey") | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", | ||||
| 		parts[0], parts[1]), nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Add("Content-Type", "application/json") | ||||
| 	req.Header.Add("Accept", "application/json") | ||||
| 	res, err := impatientHTTPClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	var accessToken BaiduAccessToken | ||||
| 	err = json.NewDecoder(res.Body).Decode(&accessToken) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if accessToken.Error != "" { | ||||
| 		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) | ||||
| 	} | ||||
| 	if accessToken.AccessToken == "" { | ||||
| 		return nil, errors.New("getBaiduAccessTokenHelper get empty access token") | ||||
| 	} | ||||
| 	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) | ||||
| 	baiduTokenStore.Store(apiKey, accessToken) | ||||
| 	return &accessToken, nil | ||||
| } | ||||
							
								
								
									
										94
									
								
								controller/relay-chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								controller/relay-chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayChat(c *gin.Context) { | ||||
|  | ||||
| 	var chatRequest types.ChatCompletionRequest | ||||
| 	if err := common.UnmarshalBodyReusable(c, &chatRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, chatRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[chatRequest.Model] != "" { | ||||
| 		chatRequest.Model = modelMap[chatRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeChatCompletions) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	chatProvider, ok := provider.(providersBase.ChatInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model) | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, chatRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
| @@ -1,220 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type ClaudeMetadata struct { | ||||
| 	UserId string `json:"user_id"` | ||||
| } | ||||
|  | ||||
| type ClaudeRequest struct { | ||||
| 	Model             string   `json:"model"` | ||||
| 	Prompt            string   `json:"prompt"` | ||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||
| 	Temperature       float64  `json:"temperature,omitempty"` | ||||
| 	TopP              float64  `json:"top_p,omitempty"` | ||||
| 	TopK              int      `json:"top_k,omitempty"` | ||||
| 	//ClaudeMetadata    `json:"metadata,omitempty"` | ||||
| 	Stream bool `json:"stream,omitempty"` | ||||
| } | ||||
|  | ||||
| type ClaudeError struct { | ||||
| 	Type    string `json:"type"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type ClaudeResponse struct { | ||||
| 	Completion string      `json:"completion"` | ||||
| 	StopReason string      `json:"stop_reason"` | ||||
| 	Model      string      `json:"model"` | ||||
| 	Error      ClaudeError `json:"error"` | ||||
| } | ||||
|  | ||||
| func stopReasonClaude2OpenAI(reason string) string { | ||||
| 	switch reason { | ||||
| 	case "stop_sequence": | ||||
| 		return "stop" | ||||
| 	case "max_tokens": | ||||
| 		return "length" | ||||
| 	default: | ||||
| 		return reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 	claudeRequest := ClaudeRequest{ | ||||
| 		Model:             textRequest.Model, | ||||
| 		Prompt:            "", | ||||
| 		MaxTokensToSample: textRequest.MaxTokens, | ||||
| 		StopSequences:     nil, | ||||
| 		Temperature:       textRequest.Temperature, | ||||
| 		TopP:              textRequest.TopP, | ||||
| 		Stream:            textRequest.Stream, | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokensToSample == 0 { | ||||
| 		claudeRequest.MaxTokensToSample = 1000000 | ||||
| 	} | ||||
| 	prompt := "" | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		if message.Role == "user" { | ||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | ||||
| 		} else if message.Role == "assistant" { | ||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||
| 		} else if message.Role == "system" { | ||||
| 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||
| 		} | ||||
| 	} | ||||
| 	prompt += "\n\nAssistant:" | ||||
| 	claudeRequest.Prompt = prompt | ||||
| 	return &claudeRequest | ||||
| } | ||||
|  | ||||
| func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = claudeResponse.Completion | ||||
| 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) | ||||
| 	if finishReason != "null" { | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = claudeResponse.Model | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { | ||||
| 			return i + 4, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if !strings.HasPrefix(data, "event: completion") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "event: completion\r\ndata: ") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var claudeResponse ClaudeResponse | ||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			responseText += claudeResponse.Completion | ||||
| 			response := streamResponseClaude2OpenAI(&claudeResponse) | ||||
| 			response.Id = responseId | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var claudeResponse ClaudeResponse | ||||
| 	err = json.Unmarshal(responseBody, &claudeResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if claudeResponse.Error.Type != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: claudeResponse.Error.Message, | ||||
| 				Type:    claudeResponse.Error.Type, | ||||
| 				Param:   "", | ||||
| 				Code:    claudeResponse.Error.Type, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||
| 	completionTokens := countTokenText(claudeResponse.Completion, model) | ||||
| 	usage := Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
							
								
								
									
										94
									
								
								controller/relay-completions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								controller/relay-completions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayCompletions(c *gin.Context) { | ||||
|  | ||||
| 	var completionRequest types.CompletionRequest | ||||
| 	if err := common.UnmarshalBodyReusable(c, &completionRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, completionRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[completionRequest.Model] != "" { | ||||
| 		completionRequest.Model = modelMap[completionRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeCompletions) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	completionProvider, ok := provider.(providersBase.CompletionInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model) | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, completionRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										93
									
								
								controller/relay-embeddings.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								controller/relay-embeddings.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayEmbeddings(c *gin.Context) { | ||||
|  | ||||
| 	var embeddingsRequest types.EmbeddingRequest | ||||
| 	if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		embeddingsRequest.Model = c.Param("model") | ||||
| 	} | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &embeddingsRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, embeddingsRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[embeddingsRequest.Model] != "" { | ||||
| 		embeddingsRequest.Model = modelMap[embeddingsRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeEmbeddings) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, embeddingsRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										106
									
								
								controller/relay-image-edits.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								controller/relay-image-edits.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayImageEdits(c *gin.Context) { | ||||
|  | ||||
| 	var imageEditRequest types.ImageEditRequest | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if imageEditRequest.Prompt == "" { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, "field prompt is required") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if imageEditRequest.Model == "" { | ||||
| 		imageEditRequest.Model = "dall-e-2" | ||||
| 	} | ||||
|  | ||||
| 	if imageEditRequest.Size == "" { | ||||
| 		imageEditRequest.Size = "1024x1024" | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, imageEditRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[imageEditRequest.Model] != "" { | ||||
| 		imageEditRequest.Model = modelMap[imageEditRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeImagesEdits) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens, err := common.CountTokenImage(imageEditRequest) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										109
									
								
								controller/relay-image-generations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								controller/relay-image-generations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayImageGenerations(c *gin.Context) { | ||||
|  | ||||
| 	var imageRequest types.ImageRequest | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &imageRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if imageRequest.Model == "" { | ||||
| 		imageRequest.Model = "dall-e-2" | ||||
| 	} | ||||
|  | ||||
| 	if imageRequest.N == 0 { | ||||
| 		imageRequest.N = 1 | ||||
| 	} | ||||
|  | ||||
| 	if imageRequest.Size == "" { | ||||
| 		imageRequest.Size = "1024x1024" | ||||
| 	} | ||||
|  | ||||
| 	if imageRequest.Quality == "" { | ||||
| 		imageRequest.Quality = "standard" | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, imageRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[imageRequest.Model] != "" { | ||||
| 		imageRequest.Model = modelMap[imageRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeImagesGenerations) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens, err := common.CountTokenImage(imageRequest) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, imageRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										101
									
								
								controller/relay-image-variationsy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								controller/relay-image-variationsy.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayImageVariations(c *gin.Context) { | ||||
|  | ||||
| 	var imageEditRequest types.ImageEditRequest | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if imageEditRequest.Model == "" { | ||||
| 		imageEditRequest.Model = "dall-e-2" | ||||
| 	} | ||||
|  | ||||
| 	if imageEditRequest.Size == "" { | ||||
| 		imageEditRequest.Size = "1024x1024" | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, imageEditRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[imageEditRequest.Model] != "" { | ||||
| 		imageEditRequest.Model = modelMap[imageEditRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeImagesVariations) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	imageVariations, ok := provider.(providersBase.ImageVariationsInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens, err := common.CountTokenImage(imageEditRequest) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
| @@ -1,182 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	imageModel := "dall-e" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	var imageRequest ImageRequest | ||||
| 	if consumeQuota { | ||||
| 		err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Not "256x256", "512x512", or "1024x1024" | ||||
| 	if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { | ||||
| 		return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// N should between 1 and 10 | ||||
| 	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { | ||||
| 		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[imageModel] != "" { | ||||
| 			imageModel = modelMap[imageModel] | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
|  | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
|  | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
|  | ||||
| 	modelRatio := common.GetModelRatio(imageModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
|  | ||||
| 	sizeRatio := 1.0 | ||||
| 	// Size | ||||
| 	if imageRequest.Size == "256x256" { | ||||
| 		sizeRatio = 1 | ||||
| 	} else if imageRequest.Size == "512x512" { | ||||
| 		sizeRatio = 1.125 | ||||
| 	} else if imageRequest.Size == "1024x1024" { | ||||
| 		sizeRatio = 1.25 | ||||
| 	} | ||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||
|  | ||||
| 	if consumeQuota && userQuota-quota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
|  | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
|  | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
|  | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var textResponse ImageResponse | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if consumeQuota { | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		} | ||||
| 	}(c.Request.Context()) | ||||
|  | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
|  | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
|  | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										93
									
								
								controller/relay-moderations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								controller/relay-moderations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayModerations(c *gin.Context) { | ||||
|  | ||||
| 	var moderationRequest types.ModerationRequest | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &moderationRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if moderationRequest.Model == "" { | ||||
| 		moderationRequest.Model = "text-moderation-stable" | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, moderationRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[moderationRequest.Model] != "" { | ||||
| 		moderationRequest.Model = modelMap[moderationRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeModerations) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	moderationProvider, ok := provider.(providersBase.ModerationInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model) | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, moderationRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
| @@ -1,144 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:6] != "data: " && data[:6] != "[DONE]" { | ||||
| 				continue | ||||
| 			} | ||||
| 			dataChan <- data | ||||
| 			data = data[6:] | ||||
| 			if !strings.HasPrefix(data, "[DONE]") { | ||||
| 				switch relayMode { | ||||
| 				case RelayModeChatCompletions: | ||||
| 					var streamResponse ChatCompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue // just ignore the error | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Delta.Content | ||||
| 					} | ||||
| 				case RelayModeCompletions: | ||||
| 					var streamResponse CompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Text | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			if strings.HasPrefix(data, "data: [DONE]") { | ||||
| 				data = data[:12] | ||||
| 			} | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			c.Render(-1, common.CustomEvent{Data: data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse TextResponse | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		if textResponse.Error.Type != "" { | ||||
| 			return &OpenAIErrorWithStatusCode{ | ||||
| 				OpenAIError: textResponse.Error, | ||||
| 				StatusCode:  resp.StatusCode, | ||||
| 			}, nil | ||||
| 		} | ||||
| 		// Reset response body | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 	// So the httpClient will be confused by the response. | ||||
| 	// For example, Postman will report error, and we cannot check the response at all. | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err := io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if textResponse.Usage.TotalTokens == 0 { | ||||
| 		completionTokens := 0 | ||||
| 		for _, choice := range textResponse.Choices { | ||||
| 			completionTokens += countTokenText(choice.Message.Content, model) | ||||
| 		} | ||||
| 		textResponse.Usage = Usage{ | ||||
| 			PromptTokens:     promptTokens, | ||||
| 			CompletionTokens: completionTokens, | ||||
| 			TotalTokens:      promptTokens + completionTokens, | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, &textResponse.Usage | ||||
| } | ||||
| @@ -1,205 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| ) | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||
|  | ||||
| type PaLMChatMessage struct { | ||||
| 	Author  string `json:"author"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type PaLMFilter struct { | ||||
| 	Reason  string `json:"reason"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type PaLMPrompt struct { | ||||
| 	Messages []PaLMChatMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| type PaLMChatRequest struct { | ||||
| 	Prompt         PaLMPrompt `json:"prompt"` | ||||
| 	Temperature    float64    `json:"temperature,omitempty"` | ||||
| 	CandidateCount int        `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64    `json:"topP,omitempty"` | ||||
| 	TopK           int        `json:"topK,omitempty"` | ||||
| } | ||||
|  | ||||
| type PaLMError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  string `json:"status"` | ||||
| } | ||||
|  | ||||
| type PaLMChatResponse struct { | ||||
| 	Candidates []PaLMChatMessage `json:"candidates"` | ||||
| 	Messages   []Message         `json:"messages"` | ||||
| 	Filters    []PaLMFilter      `json:"filters"` | ||||
| 	Error      PaLMError         `json:"error"` | ||||
| } | ||||
|  | ||||
| func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| 	palmRequest := PaLMChatRequest{ | ||||
| 		Prompt: PaLMPrompt{ | ||||
| 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), | ||||
| 		}, | ||||
| 		Temperature:    textRequest.Temperature, | ||||
| 		CandidateCount: textRequest.N, | ||||
| 		TopP:           textRequest.TopP, | ||||
| 		TopK:           textRequest.MaxTokens, | ||||
| 	} | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		palmMessage := PaLMChatMessage{ | ||||
| 			Content: message.Content, | ||||
| 		} | ||||
| 		if message.Role == "user" { | ||||
| 			palmMessage.Author = "0" | ||||
| 		} else { | ||||
| 			palmMessage.Author = "1" | ||||
| 		} | ||||
| 		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) | ||||
| 	} | ||||
| 	return &palmRequest | ||||
| } | ||||
|  | ||||
| func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | ||||
| 	} | ||||
| 	for i, candidate := range response.Candidates { | ||||
| 		choice := OpenAITextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: candidate.Content, | ||||
| 			}, | ||||
| 			FinishReason: "stop", | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	if len(palmResponse.Candidates) > 0 { | ||||
| 		choice.Delta.Content = palmResponse.Candidates[0].Content | ||||
| 	} | ||||
| 	choice.FinishReason = &stopFinishReason | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = "palm2" | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error reading stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			common.SysError("error closing stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		var palmResponse PaLMChatResponse | ||||
| 		err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) | ||||
| 		fullTextResponse.Id = responseId | ||||
| 		fullTextResponse.Created = createdTime | ||||
| 		if len(palmResponse.Candidates) > 0 { | ||||
| 			responseText = palmResponse.Candidates[0].Content | ||||
| 		} | ||||
| 		jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		dataChan <- string(jsonResponse) | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var palmResponse PaLMChatResponse | ||||
| 	err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: palmResponse.Error.Message, | ||||
| 				Type:    palmResponse.Error.Status, | ||||
| 				Param:   "", | ||||
| 				Code:    palmResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) | ||||
| 	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) | ||||
| 	usage := Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
							
								
								
									
										89
									
								
								controller/relay-speech.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								controller/relay-speech.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelaySpeech(c *gin.Context) { | ||||
|  | ||||
| 	var speechRequest types.SpeechAudioRequest | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &speechRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, speechRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[speechRequest.Model] != "" { | ||||
| 		speechRequest.Model = modelMap[speechRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeAudioSpeech) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	speechProvider, ok := provider.(providersBase.SpeechInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens := len(speechRequest.Input) | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, speechRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
| @@ -1,590 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	APITypeOpenAI = iota | ||||
| 	APITypeClaude | ||||
| 	APITypePaLM | ||||
| 	APITypeBaidu | ||||
| 	APITypeZhipu | ||||
| 	APITypeAli | ||||
| 	APITypeXunfei | ||||
| 	APITypeAIProxyLibrary | ||||
| ) | ||||
|  | ||||
| var httpClient *http.Client | ||||
| var impatientHTTPClient *http.Client | ||||
|  | ||||
| func init() { | ||||
| 	httpClient = &http.Client{} | ||||
| 	impatientHTTPClient = &http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
| 	var textRequest GeneralOpenAIRequest | ||||
| 	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { | ||||
| 		err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { | ||||
| 		textRequest.Model = "text-moderation-latest" | ||||
| 	} | ||||
| 	if relayMode == RelayModeEmbeddings && textRequest.Model == "" { | ||||
| 		textRequest.Model = c.Param("model") | ||||
| 	} | ||||
| 	// request validation | ||||
| 	if textRequest.Model == "" { | ||||
| 		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 	} | ||||
| 	switch relayMode { | ||||
| 	case RelayModeCompletions: | ||||
| 		if textRequest.Prompt == "" { | ||||
| 			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeChatCompletions: | ||||
| 		if textRequest.Messages == nil || len(textRequest.Messages) == 0 { | ||||
| 			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeEmbeddings: | ||||
| 	case RelayModeModerations: | ||||
| 		if textRequest.Input == "" { | ||||
| 			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	case RelayModeEdits: | ||||
| 		if textRequest.Instruction == "" { | ||||
| 			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" && modelMapping != "{}" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		if modelMap[textRequest.Model] != "" { | ||||
| 			textRequest.Model = modelMap[textRequest.Model] | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
| 	apiType := APITypeOpenAI | ||||
| 	switch channelType { | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		apiType = APITypeClaude | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		apiType = APITypeBaidu | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		apiType = APITypePaLM | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		apiType = APITypeZhipu | ||||
| 	case common.ChannelTypeAli: | ||||
| 		apiType = APITypeAli | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		apiType = APITypeXunfei | ||||
| 	case common.ChannelTypeAIProxyLibrary: | ||||
| 		apiType = APITypeAIProxyLibrary | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if channelType == common.ChannelTypeAzure { | ||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 			query := c.Request.URL.Query() | ||||
| 			apiVersion := query.Get("api-version") | ||||
| 			if apiVersion == "" { | ||||
| 				apiVersion = c.GetString("api_version") | ||||
| 			} | ||||
| 			requestURL := strings.Split(requestURL, "?")[0] | ||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||
| 			baseURL = c.GetString("base_url") | ||||
| 			task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 			model_ := textRequest.Model | ||||
| 			model_ = strings.Replace(model_, ".", "", -1) | ||||
| 			// https://github.com/songquanpeng/one-api/issues/67 | ||||
| 			model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | ||||
| 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) | ||||
| 		} | ||||
| 	case APITypeBaidu: | ||||
| 		switch textRequest.Model { | ||||
| 		case "ERNIE-Bot": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 		case "ERNIE-Bot-turbo": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		var err error | ||||
| 		if apiKey, err = getBaiduAccessToken(apiKey); err != nil { | ||||
| 			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		fullRequestURL += "?access_token=" + apiKey | ||||
| 	case APITypePaLM: | ||||
| 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		fullRequestURL += "?key=" + apiKey | ||||
| 	case APITypeZhipu: | ||||
| 		method := "invoke" | ||||
| 		if textRequest.Stream { | ||||
| 			method = "sse-invoke" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||
| 	case APITypeAli: | ||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||
| 		if relayMode == RelayModeEmbeddings { | ||||
| 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||
| 		} | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | ||||
| 	} | ||||
| 	var promptTokens int | ||||
| 	var completionTokens int | ||||
| 	switch relayMode { | ||||
| 	case RelayModeChatCompletions: | ||||
| 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) | ||||
| 	case RelayModeCompletions: | ||||
| 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) | ||||
| 	case RelayModeModerations: | ||||
| 		promptTokens = countTokenInput(textRequest.Input, textRequest.Model) | ||||
| 	} | ||||
| 	preConsumedTokens := common.PreConsumedQuota | ||||
| 	if textRequest.MaxTokens != 0 { | ||||
| 		preConsumedTokens = promptTokens + textRequest.MaxTokens | ||||
| 	} | ||||
| 	modelRatio := common.GetModelRatio(textRequest.Model) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota > 100*preConsumedQuota { | ||||
| 		// in this case, we do not pre-consume quota | ||||
| 		// because the user has enough quota | ||||
| 		preConsumedQuota = 0 | ||||
| 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||
| 	} | ||||
| 	if consumeQuota && preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 		} | ||||
| 	} | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
| 	switch apiType { | ||||
| 	case APITypeClaude: | ||||
| 		claudeRequest := requestOpenAI2Claude(textRequest) | ||||
| 		jsonStr, err := json.Marshal(claudeRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeBaidu: | ||||
| 		var jsonData []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case RelayModeEmbeddings: | ||||
| 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||
| 		default: | ||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonData) | ||||
| 	case APITypePaLM: | ||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) | ||||
| 		jsonStr, err := json.Marshal(palmRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeZhipu: | ||||
| 		zhipuRequest := requestOpenAI2Zhipu(textRequest) | ||||
| 		jsonStr, err := json.Marshal(zhipuRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeAli: | ||||
| 		var jsonStr []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case RelayModeEmbeddings: | ||||
| 			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) | ||||
| 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | ||||
| 		default: | ||||
| 			aliRequest := requestOpenAI2Ali(textRequest) | ||||
| 			jsonStr, err = json.Marshal(aliRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | ||||
| 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||
| 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} | ||||
|  | ||||
| 	var req *http.Request | ||||
| 	var resp *http.Response | ||||
| 	isStream := textRequest.Stream | ||||
|  | ||||
| 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		switch apiType { | ||||
| 		case APITypeOpenAI: | ||||
| 			if channelType == common.ChannelTypeAzure { | ||||
| 				req.Header.Set("api-key", apiKey) | ||||
| 			} else { | ||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 				if channelType == common.ChannelTypeOpenRouter { | ||||
| 					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") | ||||
| 					req.Header.Set("X-Title", "One API") | ||||
| 				} | ||||
| 			} | ||||
| 		case APITypeClaude: | ||||
| 			req.Header.Set("x-api-key", apiKey) | ||||
| 			anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 			if anthropicVersion == "" { | ||||
| 				anthropicVersion = "2023-06-01" | ||||
| 			} | ||||
| 			req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 		case APITypeZhipu: | ||||
| 			token := getZhipuToken(apiKey) | ||||
| 			req.Header.Set("Authorization", token) | ||||
| 		case APITypeAli: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 			if textRequest.Stream { | ||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 			} | ||||
| 		default: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 		} | ||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = req.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = c.Request.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
|  | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			if preConsumedQuota != 0 { | ||||
| 				go func(ctx context.Context) { | ||||
| 					// return pre-consumed quota | ||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||
| 					if err != nil { | ||||
| 						common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 					} | ||||
| 				}(c.Request.Context()) | ||||
| 			} | ||||
| 			return relayErrorHandler(resp) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var textResponse TextResponse | ||||
| 	tokenName := c.GetString("token_name") | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		// c.Writer.Flush() | ||||
| 		go func() { | ||||
| 			if consumeQuota { | ||||
| 				quota := 0 | ||||
| 				completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 				promptTokens = textResponse.Usage.PromptTokens | ||||
| 				completionTokens = textResponse.Usage.CompletionTokens | ||||
|  | ||||
| 				quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||
| 				quota = int(float64(quota) * ratio) | ||||
| 				if ratio != 0 && quota <= 0 { | ||||
| 					quota = 1 | ||||
| 				} | ||||
| 				totalTokens := promptTokens + completionTokens | ||||
| 				if totalTokens == 0 { | ||||
| 					// in this case, must be some error happened | ||||
| 					// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 					quota = 0 | ||||
| 				} | ||||
| 				quotaDelta := quota - preConsumedQuota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 				} | ||||
| 				err = model.CacheUpdateUserQuota(userId) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||
| 				} | ||||
| 				if quota != 0 { | ||||
| 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||
| 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 					model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 				} | ||||
| 			} | ||||
| 		}() | ||||
| 	}(c.Request.Context()) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if isStream { | ||||
| 			err, responseText := openaiStreamHandler(c, resp, relayMode) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		if isStream { | ||||
| 			err, responseText := claudeStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeBaidu: | ||||
| 		if isStream { | ||||
| 			err, usage := baiduStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			var err *OpenAIErrorWithStatusCode | ||||
| 			var usage *Usage | ||||
| 			switch relayMode { | ||||
| 			case RelayModeEmbeddings: | ||||
| 				err, usage = baiduEmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = baiduHandler(c, resp) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypePaLM: | ||||
| 		if textRequest.Stream { // PaLM2 API does not support stream | ||||
| 			err, responseText := palmStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeZhipu: | ||||
| 		if isStream { | ||||
| 			err, usage := zhipuStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := zhipuHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeAli: | ||||
| 		if isStream { | ||||
| 			err, usage := aliStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			var err *OpenAIErrorWithStatusCode | ||||
| 			var usage *Usage | ||||
| 			switch relayMode { | ||||
| 			case RelayModeEmbeddings: | ||||
| 				err, usage = aliEmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = aliHandler(c, resp) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeXunfei: | ||||
| 		auth := c.Request.Header.Get("Authorization") | ||||
| 		auth = strings.TrimPrefix(auth, "Bearer ") | ||||
| 		splits := strings.Split(auth, "|") | ||||
| 		if len(splits) != 3 { | ||||
| 			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||
| 		} | ||||
| 		var err *OpenAIErrorWithStatusCode | ||||
| 		var usage *Usage | ||||
| 		if isStream { | ||||
| 			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||
| 		} else { | ||||
| 			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if usage != nil { | ||||
| 			textResponse.Usage = *usage | ||||
| 		} | ||||
| 		return nil | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		if isStream { | ||||
| 			err, usage := aiProxyLibraryStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := aiProxyLibraryHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	default: | ||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										89
									
								
								controller/relay-transcriptions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								controller/relay-transcriptions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayTranscriptions(c *gin.Context) { | ||||
|  | ||||
| 	var audioRequest types.AudioRequest | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, audioRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[audioRequest.Model] != "" { | ||||
| 		audioRequest.Model = modelMap[audioRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranscription) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens := 0 | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = transcriptionsProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										89
									
								
								controller/relay-translations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								controller/relay-translations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func RelayTranslations(c *gin.Context) { | ||||
|  | ||||
| 	var audioRequest types.AudioRequest | ||||
|  | ||||
| 	if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	channel, pass := fetchChannel(c, audioRequest.Model) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 解析模型映射 | ||||
| 	var isModelMapped bool | ||||
| 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if modelMap != nil && modelMap[audioRequest.Model] != "" { | ||||
| 		audioRequest.Model = modelMap[audioRequest.Model] | ||||
| 		isModelMapped = true | ||||
| 	} | ||||
|  | ||||
| 	// 获取供应商 | ||||
| 	provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranslation) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
| 	translationProvider, ok := provider.(providersBase.TranslationInterface) | ||||
| 	if !ok { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 获取Input Tokens | ||||
| 	promptTokens := 0 | ||||
|  | ||||
| 	var quotaInfo *QuotaInfo | ||||
| 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||
| 	var usage *types.Usage | ||||
| 	quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens) | ||||
| 	if errWithCode != nil { | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage, errWithCode = translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens) | ||||
|  | ||||
| 	// 如果报错,则退还配额 | ||||
| 	if errWithCode != nil { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		if quotaInfo.HandelStatus { | ||||
| 			go func(ctx context.Context) { | ||||
| 				// return pre-consumed quota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||
| 				} | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		errorHelper(c, errWithCode) | ||||
| 		return | ||||
| 	} else { | ||||
| 		tokenName := c.GetString("token_name") | ||||
| 		// 如果没有报错,则消费配额 | ||||
| 		go func(ctx context.Context) { | ||||
| 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, err.Error()) | ||||
| 			} | ||||
| 		}(c.Request.Context()) | ||||
| 	} | ||||
| } | ||||
| @@ -1,132 +1,126 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"one-api/providers" | ||||
| 	providersBase "one-api/providers/base" | ||||
| 	"one-api/types" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-playground/validator/v10" | ||||
| ) | ||||
|  | ||||
| var stopFinishReason = "stop" | ||||
|  | ||||
| // tokenEncoderMap won't grow after initialization | ||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||
| var defaultTokenEncoder *tiktoken.Tiktoken | ||||
|  | ||||
| func InitTokenEncoders() { | ||||
| 	common.SysLog("initializing token encoders") | ||||
| 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||
| 	if err != nil { | ||||
| 		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	defaultTokenEncoder = gpt35TokenEncoder | ||||
| 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||
| 	if err != nil { | ||||
| 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	for model, _ := range common.ModelRatio { | ||||
| 		if strings.HasPrefix(model, "gpt-3.5") { | ||||
| 			tokenEncoderMap[model] = gpt35TokenEncoder | ||||
| 		} else if strings.HasPrefix(model, "gpt-4") { | ||||
| 			tokenEncoderMap[model] = gpt4TokenEncoder | ||||
| 		} else { | ||||
| 			tokenEncoderMap[model] = nil | ||||
| func GetValidFieldName(err error, obj interface{}) string { | ||||
| 	getObj := reflect.TypeOf(obj) | ||||
| 	if errs, ok := err.(validator.ValidationErrors); ok { | ||||
| 		for _, e := range errs { | ||||
| 			if f, exist := getObj.Elem().FieldByName(e.Field()); exist { | ||||
| 				return f.Name | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	common.SysLog("token encoders initialized") | ||||
| 	return err.Error() | ||||
| } | ||||
|  | ||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||
| 	tokenEncoder, ok := tokenEncoderMap[model] | ||||
| 	if ok && tokenEncoder != nil { | ||||
| 		return tokenEncoder | ||||
| 	} | ||||
| func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) { | ||||
| 	channelId, ok := c.Get("channelId") | ||||
| 	if ok { | ||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||
| 		if err != nil { | ||||
| 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||
| 			tokenEncoder = defaultTokenEncoder | ||||
| 		channel, pass = fetchChannelById(c, channelId.(int)) | ||||
| 		if pass { | ||||
| 			return | ||||
| 		} | ||||
| 		tokenEncoderMap[model] = tokenEncoder | ||||
| 		return tokenEncoder | ||||
|  | ||||
| 	} | ||||
| 	return defaultTokenEncoder | ||||
| 	channel, pass = fetchChannelByModel(c, modelName) | ||||
| 	if pass { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	setChannelToContext(c, channel) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||
| 	if common.ApproximateTokenEnabled { | ||||
| 		return int(float64(len(text)) * 0.38) | ||||
| func fetchChannelById(c *gin.Context, channelId any) (*model.Channel, bool) { | ||||
| 	id, err := strconv.Atoi(channelId.(string)) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||
| 		return nil, true | ||||
| 	} | ||||
| 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||
| 	channel, err := model.GetChannelById(id, true) | ||||
| 	if err != nil { | ||||
| 		common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||
| 		return nil, true | ||||
| 	} | ||||
| 	if channel.Status != common.ChannelStatusEnabled { | ||||
| 		common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||
| 		return nil, true | ||||
| 	} | ||||
|  | ||||
| 	return channel, false | ||||
| } | ||||
|  | ||||
| func countTokenMessages(messages []Message, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	// Reference: | ||||
| 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
| 	// https://github.com/pkoukk/tiktoken-go/issues/6 | ||||
| 	// | ||||
| 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n | ||||
| 	var tokensPerMessage int | ||||
| 	var tokensPerName int | ||||
| 	if model == "gpt-3.5-turbo-0301" { | ||||
| 		tokensPerMessage = 4 | ||||
| 		tokensPerName = -1 // If there's a name, the role is omitted | ||||
| 	} else { | ||||
| 		tokensPerMessage = 3 | ||||
| 		tokensPerName = 1 | ||||
| 	} | ||||
| 	tokenNum := 0 | ||||
| 	for _, message := range messages { | ||||
| 		tokenNum += tokensPerMessage | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Content) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Role) | ||||
| 		if message.Name != nil { | ||||
| 			tokenNum += tokensPerName | ||||
| 			tokenNum += getTokenNum(tokenEncoder, *message.Name) | ||||
| func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) { | ||||
| 	group := c.GetString("group") | ||||
| 	channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName) | ||||
| 	if err != nil { | ||||
| 		message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName) | ||||
| 		if channel != nil { | ||||
| 			common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 			message = "数据库一致性已被破坏,请联系管理员" | ||||
| 		} | ||||
| 		common.AbortWithMessage(c, http.StatusServiceUnavailable, message) | ||||
| 		return nil, true | ||||
| 	} | ||||
| 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> | ||||
| 	return tokenNum | ||||
|  | ||||
| 	return channel, false | ||||
| } | ||||
|  | ||||
| func countTokenInput(input any, model string) int { | ||||
| 	switch input.(type) { | ||||
| 	case string: | ||||
| 		return countTokenText(input.(string), model) | ||||
| 	case []string: | ||||
| 		text := "" | ||||
| 		for _, s := range input.([]string) { | ||||
| 			text += s | ||||
| 		} | ||||
| 		return countTokenText(text, model) | ||||
| func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.ProviderInterface, bool) { | ||||
| 	provider := providers.GetProvider(channelType, c) | ||||
| 	if provider == nil { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found") | ||||
| 		return nil, true | ||||
| 	} | ||||
| 	return 0 | ||||
|  | ||||
| 	if !provider.SupportAPI(relayMode) { | ||||
| 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API") | ||||
| 		return nil, true | ||||
| 	} | ||||
|  | ||||
| 	return provider, false | ||||
| } | ||||
|  | ||||
| func countTokenText(text string, model string) int { | ||||
| 	tokenEncoder := getTokenEncoder(model) | ||||
| 	return getTokenNum(tokenEncoder, text) | ||||
| func setChannelToContext(c *gin.Context, channel *model.Channel) { | ||||
| 	// c.Set("channel", channel.Type) | ||||
| 	c.Set("channel_id", channel.Id) | ||||
| 	c.Set("channel_name", channel.Name) | ||||
| 	c.Set("api_key", channel.Key) | ||||
| 	c.Set("base_url", channel.GetBaseURL()) | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		c.Set("api_version", channel.Other) | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		c.Set("api_version", channel.Other) | ||||
| 	case common.ChannelTypeGemini: | ||||
| 		c.Set("api_version", channel.Other) | ||||
| 	case common.ChannelTypeAIProxyLibrary: | ||||
| 		c.Set("library_id", channel.Other) | ||||
| 	case common.ChannelTypeAli: | ||||
| 		c.Set("plugin", channel.Other) | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { | ||||
| 	openAIError := OpenAIError{ | ||||
| 		Message: err.Error(), | ||||
| 		Type:    "one_api_error", | ||||
| 		Code:    code, | ||||
| 	} | ||||
| 	return &OpenAIErrorWithStatusCode{ | ||||
| 		OpenAIError: openAIError, | ||||
| 		StatusCode:  statusCode, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func shouldDisableChannel(err *OpenAIError, statusCode int) bool { | ||||
| func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool { | ||||
| 	if !common.AutomaticDisableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| @@ -142,37 +136,165 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func setEventStreamHeaders(c *gin.Context) { | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { | ||||
| 	if !common.AutomaticEnableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if openAIErr != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { | ||||
| 	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ | ||||
| 		StatusCode: resp.StatusCode, | ||||
| 		OpenAIError: OpenAIError{ | ||||
| 			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), | ||||
| 			Type:    "upstream_error", | ||||
| 			Code:    "bad_response_status_code", | ||||
| 			Param:   strconv.Itoa(resp.StatusCode), | ||||
| 		}, | ||||
| 	} | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||
| 	// quotaDelta is remaining quota to be consumed | ||||
| 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	err = model.CacheUpdateUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		common.SysError("error update user quota cache: " + err.Error()) | ||||
| 	} | ||||
| 	var textResponse TextResponse | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	// totalQuota is total quota consumed | ||||
| 	if totalQuota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||
| 	} | ||||
| 	if totalQuota <= 0 { | ||||
| 		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func parseModelMapping(modelMapping string) (map[string]string, error) { | ||||
| 	if modelMapping == "" || modelMapping == "{}" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	modelMap := make(map[string]string) | ||||
| 	err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error | ||||
| 	return modelMap, nil | ||||
| } | ||||
|  | ||||
| type QuotaInfo struct { | ||||
| 	modelName         string | ||||
| 	promptTokens      int | ||||
| 	preConsumedTokens int | ||||
| 	modelRatio        float64 | ||||
| 	groupRatio        float64 | ||||
| 	ratio             float64 | ||||
| 	preConsumedQuota  int | ||||
| 	userId            int | ||||
| 	channelId         int | ||||
| 	tokenId           int | ||||
| 	HandelStatus      bool | ||||
| } | ||||
|  | ||||
| func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) { | ||||
| 	quotaInfo := &QuotaInfo{ | ||||
| 		modelName:    modelName, | ||||
| 		promptTokens: promptTokens, | ||||
| 		userId:       c.GetInt("id"), | ||||
| 		channelId:    c.GetInt("channel_id"), | ||||
| 		tokenId:      c.GetInt("token_id"), | ||||
| 		HandelStatus: false, | ||||
| 	} | ||||
| 	quotaInfo.initQuotaInfo(c.GetString("group")) | ||||
|  | ||||
| 	errWithCode := quotaInfo.preQuotaConsumption() | ||||
| 	if errWithCode != nil { | ||||
| 		return nil, errWithCode | ||||
| 	} | ||||
|  | ||||
| 	return quotaInfo, nil | ||||
| } | ||||
|  | ||||
| func (q *QuotaInfo) initQuotaInfo(groupName string) { | ||||
| 	modelRatio := common.GetModelRatio(q.modelName) | ||||
| 	groupRatio := common.GetGroupRatio(groupName) | ||||
| 	preConsumedTokens := common.PreConsumedQuota | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio) | ||||
|  | ||||
| 	q.preConsumedTokens = preConsumedTokens | ||||
| 	q.modelRatio = modelRatio | ||||
| 	q.groupRatio = groupRatio | ||||
| 	q.ratio = ratio | ||||
| 	q.preConsumedQuota = preConsumedQuota | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { | ||||
| 	userQuota, err := model.CacheGetUserQuota(q.userId) | ||||
| 	if err != nil { | ||||
| 		return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if userQuota < q.preConsumedQuota { | ||||
| 		return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
|  | ||||
| 	err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if userQuota > 100*q.preConsumedQuota { | ||||
| 		// in this case, we do not pre-consume quota | ||||
| 		// because the user has enough quota | ||||
| 		q.preConsumedQuota = 0 | ||||
| 		// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||
| 	} | ||||
|  | ||||
| 	if q.preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| 		} | ||||
| 		q.HandelStatus = true | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error { | ||||
| 	quota := 0 | ||||
| 	completionRatio := common.GetCompletionRatio(q.modelName) | ||||
| 	promptTokens := usage.PromptTokens | ||||
| 	completionTokens := usage.CompletionTokens | ||||
| 	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio)) | ||||
| 	if q.ratio != 0 && quota <= 0 { | ||||
| 		quota = 1 | ||||
| 	} | ||||
| 	totalTokens := promptTokens + completionTokens | ||||
| 	if totalTokens == 0 { | ||||
| 		// in this case, must be some error happened | ||||
| 		// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 		quota = 0 | ||||
| 	} | ||||
| 	quotaDelta := quota - q.preConsumedQuota | ||||
| 	err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta) | ||||
| 	if err != nil { | ||||
| 		return errors.New("error consuming token remain quota: " + err.Error()) | ||||
| 	} | ||||
| 	err = model.CacheUpdateUserQuota(q.userId) | ||||
| 	if err != nil { | ||||
| 		return errors.New("error consuming token remain quota: " + err.Error()) | ||||
| 	} | ||||
| 	if quota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota) | ||||
| 		model.UpdateChannelUsedQuota(q.channelId, quota) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,303 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://console.xfyun.cn/services/cbm | ||||
| // https://www.xfyun.cn/doc/spark/Web.html | ||||
|  | ||||
| type XunfeiMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatRequest struct { | ||||
| 	Header struct { | ||||
| 		AppId string `json:"app_id"` | ||||
| 	} `json:"header"` | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string  `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int     `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool    `json:"auditing,omitempty"` | ||||
| 		} `json:"chat"` | ||||
| 	} `json:"parameter"` | ||||
| 	Payload struct { | ||||
| 		Message struct { | ||||
| 			Text []XunfeiMessage `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponse struct { | ||||
| 	Header struct { | ||||
| 		Code    int    `json:"code"` | ||||
| 		Message string `json:"message"` | ||||
| 		Sid     string `json:"sid"` | ||||
| 		Status  int    `json:"status"` | ||||
| 	} `json:"header"` | ||||
| 	Payload struct { | ||||
| 		Choices struct { | ||||
| 			Status int                          `json:"status"` | ||||
| 			Seq    int                          `json:"seq"` | ||||
| 			Text   []XunfeiChatResponseTextItem `json:"text"` | ||||
| 		} `json:"choices"` | ||||
| 		Usage struct { | ||||
| 			//Text struct { | ||||
| 			//	QuestionTokens   string `json:"question_tokens"` | ||||
| 			//	PromptTokens     string `json:"prompt_tokens"` | ||||
| 			//	CompletionTokens string `json:"completion_tokens"` | ||||
| 			//	TotalTokens      string `json:"total_tokens"` | ||||
| 			//} `json:"text"` | ||||
| 			Text Usage `json:"text"` | ||||
| 		} `json:"usage"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { | ||||
| 	messages := make([]XunfeiMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	xunfeiRequest := XunfeiChatRequest{} | ||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||
| 	xunfeiRequest.Parameter.Chat.Domain = domain | ||||
| 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||
| 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
| 	return &xunfeiRequest | ||||
| } | ||||
|  | ||||
| func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Payload.Choices.Text[0].Content, | ||||
| 		}, | ||||
| 		FinishReason: stopFinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage:   response.Payload.Usage.Text, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 	if xunfeiResponse.Payload.Choices.Status == 2 { | ||||
| 		choice.FinishReason = &stopFinishReason | ||||
| 	} | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "SparkDesk", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||
| 	HmacWithShaToBase64 := func(algorithm, data, key string) string { | ||||
| 		mac := hmac.New(sha256.New, []byte(key)) | ||||
| 		mac.Write([]byte(data)) | ||||
| 		encodeData := mac.Sum(nil) | ||||
| 		return base64.StdEncoding.EncodeToString(encodeData) | ||||
| 	} | ||||
| 	ul, err := url.Parse(hostUrl) | ||||
| 	if err != nil { | ||||
| 		fmt.Println(err) | ||||
| 	} | ||||
| 	date := time.Now().UTC().Format(time.RFC1123) | ||||
| 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||
| 	sign := strings.Join(signString, "\n") | ||||
| 	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) | ||||
| 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||
| 		"hmac-sha256", "host date request-line", sha) | ||||
| 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||
| 	v := url.Values{} | ||||
| 	v.Add("host", ul.Host) | ||||
| 	v.Add("date", date) | ||||
| 	v.Add("authorization", authorization) | ||||
| 	callUrl := hostUrl + "?" + v.Encode() | ||||
| 	return callUrl | ||||
| } | ||||
|  | ||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	setEventStreamHeaders(c) | ||||
| 	var usage Usage | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case xunfeiResponse := <-dataChan: | ||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||
| 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||
| 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var usage Usage | ||||
| 	var content string | ||||
| 	var xunfeiResponse XunfeiChatResponse | ||||
| 	stop := false | ||||
| 	for !stop { | ||||
| 		select { | ||||
| 		case xunfeiResponse = <-dataChan: | ||||
| 			content += xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||
| 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||
| 		case stop = <-stopChan: | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	xunfeiResponse.Payload.Choices.Text[0].Content = content | ||||
|  | ||||
| 	response := responseXunfei2OpenAI(&xunfeiResponse) | ||||
| 	jsonResponse, err := json.Marshal(response) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	_, _ = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { | ||||
| 	d := websocket.Dialer{ | ||||
| 		HandshakeTimeout: 5 * time.Second, | ||||
| 	} | ||||
| 	conn, resp, err := d.Dial(authUrl, nil) | ||||
| 	if err != nil || resp.StatusCode != 101 { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	data := requestOpenAI2Xunfei(textRequest, appId, domain) | ||||
| 	err = conn.WriteJSON(data) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	dataChan := make(chan XunfeiChatResponse) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := conn.ReadMessage() | ||||
| 			if err != nil { | ||||
| 				common.SysError("error reading stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			var response XunfeiChatResponse | ||||
| 			err = json.Unmarshal(msg, &response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			dataChan <- response | ||||
| 			if response.Payload.Choices.Status == 2 { | ||||
| 				err := conn.Close() | ||||
| 				if err != nil { | ||||
| 					common.SysError("error closing websocket connection: " + err.Error()) | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
|  | ||||
| 	return dataChan, stopChan, nil | ||||
| } | ||||
|  | ||||
| func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { | ||||
| 	query := c.Request.URL.Query() | ||||
| 	apiVersion := query.Get("api-version") | ||||
| 	if apiVersion == "" { | ||||
| 		apiVersion = c.GetString("api_version") | ||||
| 	} | ||||
| 	if apiVersion == "" { | ||||
| 		apiVersion = "v1.1" | ||||
| 		common.SysLog("api_version not found, use default: " + apiVersion) | ||||
| 	} | ||||
| 	domain := "general" | ||||
| 	if apiVersion == "v2.1" { | ||||
| 		domain = "generalv2" | ||||
| 	} | ||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	return domain, authUrl | ||||
| } | ||||
| @@ -1,301 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/golang-jwt/jwt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://open.bigmodel.cn/doc/api#chatglm_std | ||||
| // chatglm_std, chatglm_lite | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke | ||||
|  | ||||
| type ZhipuMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type ZhipuRequest struct { | ||||
| 	Prompt      []ZhipuMessage `json:"prompt"` | ||||
| 	Temperature float64        `json:"temperature,omitempty"` | ||||
| 	TopP        float64        `json:"top_p,omitempty"` | ||||
| 	RequestId   string         `json:"request_id,omitempty"` | ||||
| 	Incremental bool           `json:"incremental,omitempty"` | ||||
| } | ||||
|  | ||||
| type ZhipuResponseData struct { | ||||
| 	TaskId     string         `json:"task_id"` | ||||
| 	RequestId  string         `json:"request_id"` | ||||
| 	TaskStatus string         `json:"task_status"` | ||||
| 	Choices    []ZhipuMessage `json:"choices"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ZhipuResponse struct { | ||||
| 	Code    int               `json:"code"` | ||||
| 	Msg     string            `json:"msg"` | ||||
| 	Success bool              `json:"success"` | ||||
| 	Data    ZhipuResponseData `json:"data"` | ||||
| } | ||||
|  | ||||
| type ZhipuStreamMetaResponse struct { | ||||
| 	RequestId  string `json:"request_id"` | ||||
| 	TaskId     string `json:"task_id"` | ||||
| 	TaskStatus string `json:"task_status"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
|  | ||||
| type zhipuTokenData struct { | ||||
| 	Token      string | ||||
| 	ExpiryTime time.Time | ||||
| } | ||||
|  | ||||
| var zhipuTokens sync.Map | ||||
| var expSeconds int64 = 24 * 3600 | ||||
|  | ||||
| func getZhipuToken(apikey string) string { | ||||
| 	data, ok := zhipuTokens.Load(apikey) | ||||
| 	if ok { | ||||
| 		tokenData := data.(zhipuTokenData) | ||||
| 		if time.Now().Before(tokenData.ExpiryTime) { | ||||
| 			return tokenData.Token | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	split := strings.Split(apikey, ".") | ||||
| 	if len(split) != 2 { | ||||
| 		common.SysError("invalid zhipu key: " + apikey) | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	id := split[0] | ||||
| 	secret := split[1] | ||||
|  | ||||
| 	expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 | ||||
| 	expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) | ||||
|  | ||||
| 	timestamp := time.Now().UnixNano() / 1e6 | ||||
|  | ||||
| 	payload := jwt.MapClaims{ | ||||
| 		"api_key":   id, | ||||
| 		"exp":       expMillis, | ||||
| 		"timestamp": timestamp, | ||||
| 	} | ||||
|  | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) | ||||
|  | ||||
| 	token.Header["alg"] = "HS256" | ||||
| 	token.Header["sign_type"] = "SIGN" | ||||
|  | ||||
| 	tokenString, err := token.SignedString([]byte(secret)) | ||||
| 	if err != nil { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	zhipuTokens.Store(apikey, zhipuTokenData{ | ||||
| 		Token:      tokenString, | ||||
| 		ExpiryTime: expiryTime, | ||||
| 	}) | ||||
|  | ||||
| 	return tokenString | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &ZhipuRequest{ | ||||
| 		Prompt:      messages, | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Incremental: false, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.Data.TaskId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), | ||||
| 		Usage:   response.Data.Usage, | ||||
| 	} | ||||
| 	for i, choice := range response.Data.Choices { | ||||
| 		openaiChoice := OpenAITextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 				Role:    choice.Role, | ||||
| 				Content: strings.Trim(choice.Content, "\""), | ||||
| 			}, | ||||
| 			FinishReason: "", | ||||
| 		} | ||||
| 		if i == len(response.Data.Choices)-1 { | ||||
| 			openaiChoice.FinishReason = "stop" | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = zhipuResponse | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = "" | ||||
| 	choice.FinishReason = &stopFinishReason | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      zhipuResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response, &zhipuResponse.Usage | ||||
| } | ||||
|  | ||||
| func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage *Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { | ||||
| 			return i + 2, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	metaChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			lines := strings.Split(data, "\n") | ||||
| 			for i, line := range lines { | ||||
| 				if len(line) < 5 { | ||||
| 					continue | ||||
| 				} | ||||
| 				if line[:5] == "data:" { | ||||
| 					dataChan <- line[5:] | ||||
| 					if i != len(lines)-1 { | ||||
| 						dataChan <- "\n" | ||||
| 					} | ||||
| 				} else if line[:5] == "meta:" { | ||||
| 					metaChan <- line[5:] | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			response := streamResponseZhipu2OpenAI(data) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case data := <-metaChan: | ||||
| 			var zhipuResponse ZhipuStreamMetaResponse | ||||
| 			err := json.Unmarshal([]byte(data), &zhipuResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage = zhipuUsage | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
| func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var zhipuResponse ZhipuResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if !zhipuResponse.Success { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: zhipuResponse.Msg, | ||||
| 				Type:    "zhipu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    zhipuResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -4,228 +4,14 @@ import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/types" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string  `json:"role"` | ||||
| 	Content string  `json:"content"` | ||||
| 	Name    *string `json:"name,omitempty"` | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	RelayModeUnknown = iota | ||||
| 	RelayModeChatCompletions | ||||
| 	RelayModeCompletions | ||||
| 	RelayModeEmbeddings | ||||
| 	RelayModeModerations | ||||
| 	RelayModeImagesGenerations | ||||
| 	RelayModeEdits | ||||
| 	RelayModeAudio | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/chat | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Model       string    `json:"model,omitempty"` | ||||
| 	Messages    []Message `json:"messages,omitempty"` | ||||
| 	Prompt      any       `json:"prompt,omitempty"` | ||||
| 	Stream      bool      `json:"stream,omitempty"` | ||||
| 	MaxTokens   int       `json:"max_tokens,omitempty"` | ||||
| 	Temperature float64   `json:"temperature,omitempty"` | ||||
| 	TopP        float64   `json:"top_p,omitempty"` | ||||
| 	N           int       `json:"n,omitempty"` | ||||
| 	Input       any       `json:"input,omitempty"` | ||||
| 	Instruction string    `json:"instruction,omitempty"` | ||||
| 	Size        string    `json:"size,omitempty"` | ||||
| 	Functions   any       `json:"functions,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
| 	if r.Input == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	var input []string | ||||
| 	switch r.Input.(type) { | ||||
| 	case string: | ||||
| 		input = []string{r.Input.(string)} | ||||
| 	case []any: | ||||
| 		input = make([]string, 0, len(r.Input.([]any))) | ||||
| 		for _, item := range r.Input.([]any) { | ||||
| 			if str, ok := item.(string); ok { | ||||
| 				input = append(input, str) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return input | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	Messages  []Message `json:"messages"` | ||||
| 	MaxTokens int       `json:"max_tokens"` | ||||
| } | ||||
|  | ||||
| type TextRequest struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	Messages  []Message `json:"messages"` | ||||
| 	Prompt    string    `json:"prompt"` | ||||
| 	MaxTokens int       `json:"max_tokens"` | ||||
| 	//Stream   bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| type ImageRequest struct { | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	N      int    `json:"n"` | ||||
| 	Size   string `json:"size"` | ||||
| } | ||||
|  | ||||
| type AudioResponse struct { | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	PromptTokens     int `json:"prompt_tokens"` | ||||
| 	CompletionTokens int `json:"completion_tokens"` | ||||
| 	TotalTokens      int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type OpenAIError struct { | ||||
| 	Message string `json:"message"` | ||||
| 	Type    string `json:"type"` | ||||
| 	Param   string `json:"param"` | ||||
| 	Code    any    `json:"code"` | ||||
| } | ||||
|  | ||||
| type OpenAIErrorWithStatusCode struct { | ||||
| 	OpenAIError | ||||
| 	StatusCode int `json:"status_code"` | ||||
| } | ||||
|  | ||||
| type TextResponse struct { | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| 	Error   OpenAIError `json:"error"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponseChoice struct { | ||||
| 	Index        int `json:"index"` | ||||
| 	Message      `json:"message"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponse struct { | ||||
| 	Id      string                     `json:"id"` | ||||
| 	Object  string                     `json:"object"` | ||||
| 	Created int64                      `json:"created"` | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponseItem struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Index     int       `json:"index"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponse struct { | ||||
| 	Object string                        `json:"object"` | ||||
| 	Data   []OpenAIEmbeddingResponseItem `json:"data"` | ||||
| 	Model  string                        `json:"model"` | ||||
| 	Usage  `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ImageResponse struct { | ||||
| 	Created int `json:"created"` | ||||
| 	Data    []struct { | ||||
| 		Url string `json:"url"` | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponseChoice struct { | ||||
| 	Delta struct { | ||||
| 		Content string `json:"content"` | ||||
| 	} `json:"delta"` | ||||
| 	FinishReason *string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponse struct { | ||||
| 	Id      string                                `json:"id"` | ||||
| 	Object  string                                `json:"object"` | ||||
| 	Created int64                                 `json:"created"` | ||||
| 	Model   string                                `json:"model"` | ||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||
| } | ||||
|  | ||||
| type CompletionsStreamResponse struct { | ||||
| 	Choices []struct { | ||||
| 		Text         string `json:"text"` | ||||
| 		FinishReason string `json:"finish_reason"` | ||||
| 	} `json:"choices"` | ||||
| } | ||||
|  | ||||
| func Relay(c *gin.Context) { | ||||
| 	relayMode := RelayModeUnknown | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { | ||||
| 		relayMode = RelayModeChatCompletions | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { | ||||
| 		relayMode = RelayModeCompletions | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 		relayMode = RelayModeModerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 		relayMode = RelayModeImagesGenerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||
| 		relayMode = RelayModeEdits | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 		relayMode = RelayModeAudio | ||||
| 	} | ||||
| 	var err *OpenAIErrorWithStatusCode | ||||
| 	switch relayMode { | ||||
| 	case RelayModeImagesGenerations: | ||||
| 		err = relayImageHelper(c, relayMode) | ||||
| 	case RelayModeAudio: | ||||
| 		err = relayAudioHelper(c, relayMode) | ||||
| 	default: | ||||
| 		err = relayTextHelper(c, relayMode) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		requestId := c.GetString(common.RequestIdKey) | ||||
| 		retryTimesStr := c.Query("retry") | ||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||
| 		if retryTimesStr == "" { | ||||
| 			retryTimes = common.RetryTimes | ||||
| 		} | ||||
| 		if retryTimes > 0 { | ||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||
| 		} else { | ||||
| 			if err.StatusCode == http.StatusTooManyRequests { | ||||
| 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | ||||
| 			} | ||||
| 			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | ||||
| 			c.JSON(err.StatusCode, gin.H{ | ||||
| 				"error": err.OpenAIError, | ||||
| 			}) | ||||
| 		} | ||||
| 		channelId := c.GetInt("channel_id") | ||||
| 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			channelName := c.GetString("channel_name") | ||||
| 			disableChannel(channelId, channelName, err.Message) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RelayNotImplemented(c *gin.Context) { | ||||
| 	err := OpenAIError{ | ||||
| 	err := types.OpenAIError{ | ||||
| 		Message: "API not implemented", | ||||
| 		Type:    "one_api_error", | ||||
| 		Param:   "", | ||||
| @@ -237,7 +23,7 @@ func RelayNotImplemented(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func RelayNotFound(c *gin.Context) { | ||||
| 	err := OpenAIError{ | ||||
| 	err := types.OpenAIError{ | ||||
| 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | ||||
| 		Type:    "invalid_request_error", | ||||
| 		Param:   "", | ||||
| @@ -247,3 +33,31 @@ func RelayNotFound(c *gin.Context) { | ||||
| 		"error": err, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func errorHelper(c *gin.Context, err *types.OpenAIErrorWithStatusCode) { | ||||
| 	requestId := c.GetString(common.RequestIdKey) | ||||
| 	retryTimesStr := c.Query("retry") | ||||
| 	retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||
| 	if retryTimesStr == "" { | ||||
| 		retryTimes = common.RetryTimes | ||||
| 	} | ||||
| 	if retryTimes > 0 { | ||||
| 		c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||
| 	} else { | ||||
| 		if err.StatusCode == http.StatusTooManyRequests { | ||||
| 			err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | ||||
| 		} | ||||
| 		err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | ||||
| 		c.JSON(err.StatusCode, gin.H{ | ||||
| 			"error": err.OpenAIError, | ||||
| 		}) | ||||
| 	} | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 	if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | ||||
| 		channelId := c.GetInt("channel_id") | ||||
| 		channelName := c.GetString("channel_name") | ||||
| 		disableChannel(channelId, channelName, err.Message) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -248,6 +249,30 @@ func GetUser(c *gin.Context) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func GetUserDashboard(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	// 获取7天前 00:00:00 和 今天23:59:59  的秒时间戳 | ||||
| 	now := time.Now() | ||||
| 	toDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) | ||||
| 	endOfDay := toDay.Add(time.Hour * 24).Add(-time.Second).Unix() | ||||
| 	startOfDay := toDay.AddDate(0, 0, -7).Unix() | ||||
|  | ||||
| 	dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay)) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "无法获取统计信息.", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    dashboards, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func GenerateAccessToken(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	user, err := model.GetUserById(id, true) | ||||
|   | ||||
| @@ -9,21 +9,21 @@ services: | ||||
|     ports: | ||||
|       - "3000:3000" | ||||
|     volumes: | ||||
|       - ./data:/data | ||||
|       - ./data/oneapi:/data | ||||
|       - ./logs:/app/logs | ||||
|     environment: | ||||
|       - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库 | ||||
|       - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库 | ||||
|       - REDIS_CONN_STRING=redis://redis | ||||
|       - SESSION_SECRET=random_string  # 修改为随机字符串 | ||||
|       - TZ=Asia/Shanghai | ||||
| #      - NODE_TYPE=slave  # 多机部署时从节点取消注释该行 | ||||
| #      - SYNC_FREQUENCY=60  # 需要定期从数据库加载数据时取消注释该行 | ||||
| #      - FRONTEND_BASE_URL=https://openai.justsong.cn  # 多机部署时从节点取消注释该行 | ||||
|  | ||||
|     depends_on: | ||||
|       - redis | ||||
|       - db | ||||
|     healthcheck: | ||||
|       test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] | ||||
|       test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] | ||||
|       interval: 30s | ||||
|       timeout: 10s | ||||
|       retries: 3 | ||||
| @@ -32,3 +32,18 @@ services: | ||||
|     image: redis:latest | ||||
|     container_name: redis | ||||
|     restart: always | ||||
|  | ||||
|   db: | ||||
|     image: mysql:8.2.0 | ||||
|     restart: always | ||||
|     container_name: mysql | ||||
|     volumes: | ||||
|       - ./data/mysql:/var/lib/mysql  # 挂载目录,持久化存储 | ||||
|     ports: | ||||
|       - '3306:3306' | ||||
|     environment: | ||||
|       TZ: Asia/Shanghai   # 设置时区 | ||||
|       MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码 | ||||
|       MYSQL_USER: oneapi   # 创建专用用户 | ||||
|       MYSQL_PASSWORD: '123456'    # 设置专用用户密码 | ||||
|       MYSQL_DATABASE: one-api   # 自动创建数据库 | ||||
							
								
								
									
										15
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								go.mod
									
									
									
									
									
								
							| @@ -15,8 +15,11 @@ require ( | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||
| 	golang.org/x/crypto v0.9.0 | ||||
| 	github.com/stretchr/testify v1.8.3 | ||||
| 	golang.org/x/crypto v0.14.0 | ||||
| 	golang.org/x/image v0.14.0 | ||||
| 	gorm.io/driver/mysql v1.4.3 | ||||
| 	gorm.io/driver/postgres v1.5.2 | ||||
| 	gorm.io/driver/sqlite v1.4.3 | ||||
| 	gorm.io/gorm v1.25.0 | ||||
| ) | ||||
| @@ -25,6 +28,7 @@ require ( | ||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | ||||
| 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||
| 	github.com/dlclark/regexp2 v1.10.0 // indirect | ||||
| 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | ||||
| @@ -41,6 +45,7 @@ require ( | ||||
| 	github.com/jackc/pgx/v5 v5.3.1 // indirect | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/jinzhu/now v1.1.5 // indirect | ||||
| 	github.com/joho/godotenv v1.5.1 // indirect | ||||
| 	github.com/json-iterator/go v1.1.12 // indirect | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | ||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | ||||
| @@ -49,13 +54,13 @@ require ( | ||||
| 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect | ||||
| 	github.com/modern-go/reflect2 v1.0.2 // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| 	golang.org/x/net v0.10.0 // indirect | ||||
| 	golang.org/x/sys v0.8.0 // indirect | ||||
| 	golang.org/x/text v0.9.0 // indirect | ||||
| 	golang.org/x/net v0.17.0 // indirect | ||||
| 	golang.org/x/sys v0.13.0 // indirect | ||||
| 	golang.org/x/text v0.14.0 // indirect | ||||
| 	google.golang.org/protobuf v1.30.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| 	gorm.io/driver/postgres v1.5.2 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										21
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								go.sum
									
									
									
									
									
								
							| @@ -80,6 +80,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr | ||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | ||||
| github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= | ||||
| github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= | ||||
| github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= | ||||
| github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | ||||
| github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | ||||
| @@ -150,11 +152,13 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu | ||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= | ||||
| golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= | ||||
| golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= | ||||
| golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= | ||||
| golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= | ||||
| golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= | ||||
| golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= | ||||
| golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | ||||
| golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | ||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| @@ -162,14 +166,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc | ||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= | ||||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= | ||||
| golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= | ||||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | ||||
| golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= | ||||
| golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| @@ -198,7 +202,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp | ||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | ||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | ||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | ||||
| gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= | ||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | ||||
| gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | ||||
| gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||
|   | ||||
| @@ -119,6 +119,7 @@ | ||||
|   " 年 ": " y ", | ||||
|   "未测试": "Not tested", | ||||
|   "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", | ||||
|   "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", | ||||
|   "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", | ||||
| @@ -139,6 +140,7 @@ | ||||
|   "启用": "Enable", | ||||
|   "编辑": "Edit", | ||||
|   "添加新的渠道": "Add a new channel", | ||||
|   "测试所有通道": "Test all channels", | ||||
|   "测试所有已启用通道": "Test all enabled channels", | ||||
|   "更新所有已启用通道余额": "Update the balance of all enabled channels", | ||||
|   "刷新": "Refresh", | ||||
|   | ||||
							
								
								
									
										9
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								main.go
									
									
									
									
									
								
							| @@ -3,9 +3,6 @@ package main | ||||
| import ( | ||||
| 	"embed" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-contrib/sessions/cookie" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"one-api/common" | ||||
| 	"one-api/controller" | ||||
| 	"one-api/middleware" | ||||
| @@ -13,6 +10,10 @@ import ( | ||||
| 	"one-api/router" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-contrib/sessions/cookie" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| //go:embed web/build | ||||
| @@ -82,7 +83,7 @@ func main() { | ||||
| 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | ||||
| 		model.InitBatchUpdater() | ||||
| 	} | ||||
| 	controller.InitTokenEncoders() | ||||
| 	common.InitTokenEncoders() | ||||
|  | ||||
| 	// Initialize HTTP server | ||||
| 	server := gin.New() | ||||
|   | ||||
| @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 		c.Set("id", token.UserId) | ||||
| 		c.Set("token_id", token.Id) | ||||
| 		c.Set("token_name", token.Name) | ||||
| 		requestURL := c.Request.URL.String() | ||||
| 		consumeQuota := true | ||||
| 		if strings.HasPrefix(requestURL, "/v1/models") { | ||||
| 			consumeQuota = false | ||||
| 		} | ||||
| 		c.Set("consume_quota", consumeQuota) | ||||
| 		if len(parts) > 1 { | ||||
| 			if model.IsAdmin(token.UserId) { | ||||
| 				c.Set("channelId", parts[1]) | ||||
|   | ||||
| @@ -1,98 +1,16 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| } | ||||
|  | ||||
| func Distribute() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		c.Set("group", userGroup) | ||||
| 		var channel *model.Channel | ||||
| 		channelId, ok := c.Get("channelId") | ||||
| 		if ok { | ||||
| 			id, err := strconv.Atoi(channelId.(string)) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||
| 				return | ||||
| 			} | ||||
| 			channel, err = model.GetChannelById(id, true) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||
| 				return | ||||
| 			} | ||||
| 			if channel.Status != common.ChannelStatusEnabled { | ||||
| 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			// Select a channel for the user | ||||
| 			var modelRequest ModelRequest | ||||
| 			var err error | ||||
| 			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 				err = common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||
| 				return | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "text-moderation-stable" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = c.Param("model") | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "whisper-1" | ||||
| 				} | ||||
| 			} | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				if channel != nil { | ||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					message = "数据库一致性已被破坏,请联系管理员" | ||||
| 				} | ||||
| 				abortWithMessage(c, http.StatusServiceUnavailable, message) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		c.Set("channel", channel.Type) | ||||
| 		c.Set("channel_id", channel.Id) | ||||
| 		c.Set("channel_name", channel.Name) | ||||
| 		c.Set("model_mapping", channel.GetModelMapping()) | ||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
| 		c.Set("base_url", channel.GetBaseURL()) | ||||
| 		switch channel.Type { | ||||
| 		case common.ChannelTypeAzure: | ||||
| 			c.Set("api_version", channel.Other) | ||||
| 		case common.ChannelTypeXunfei: | ||||
| 			c.Set("api_version", channel.Other) | ||||
| 		case common.ChannelTypeAIProxyLibrary: | ||||
| 			c.Set("library_id", channel.Other) | ||||
| 		} | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										26
									
								
								middleware/recover.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								middleware/recover.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| ) | ||||
|  | ||||
| func RelayPanicRecover() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		defer func() { | ||||
| 			if err := recover(); err != nil { | ||||
| 				common.SysError(fmt.Sprintf("panic detected: %v", err)) | ||||
| 				c.JSON(http.StatusInternalServerError, gin.H{ | ||||
| 					"error": gin.H{ | ||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | ||||
| 						"type":    "one_api_panic", | ||||
| 					}, | ||||
| 				}) | ||||
| 				c.Abort() | ||||
| 			} | ||||
| 		}() | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
| @@ -15,10 +15,17 @@ type Ability struct { | ||||
|  | ||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	ability := Ability{} | ||||
| 	groupCol := "`group`" | ||||
| 	trueVal := "1" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 		trueVal = "true" | ||||
| 	} | ||||
|  | ||||
| 	var err error = nil | ||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) | ||||
| 	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	if common.UsingSQLite { | ||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||
| 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	if common.UsingSQLite || common.UsingPostgreSQL { | ||||
| 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||
| 	} else { | ||||
| 		err = channelQuery.Order("RAND()").First(&ability).Error | ||||
|   | ||||
| @@ -21,14 +21,18 @@ var ( | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	var token Token | ||||
| 	if !common.RedisEnabled { | ||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | ||||
| 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||
| 		return &token, err | ||||
| 	} | ||||
| 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) | ||||
| 	if err != nil { | ||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | ||||
| 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type Channel struct { | ||||
| @@ -38,7 +39,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| } | ||||
|  | ||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| @@ -53,17 +58,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { | ||||
| 	return &channel, err | ||||
| } | ||||
|  | ||||
| func GetRandomChannel() (*Channel, error) { | ||||
| 	channel := Channel{} | ||||
| 	var err error = nil | ||||
| 	if common.UsingSQLite { | ||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error | ||||
| 	} else { | ||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error | ||||
| 	} | ||||
| 	return &channel, err | ||||
| } | ||||
|  | ||||
| func BatchInsertChannels(channels []Channel) error { | ||||
| 	var err error | ||||
| 	err = DB.Create(&channels).Error | ||||
| @@ -176,3 +170,13 @@ func updateChannelUsedQuota(id int, quota int) { | ||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func DeleteChannelByStatus(status int64) (int64, error) { | ||||
| 	result := DB.Where("status = ?", status).Delete(&Channel{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|  | ||||
| func DeleteDisabledChannel() (int64, error) { | ||||
| 	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|   | ||||
							
								
								
									
										58
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										58
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -3,8 +3,9 @@ package model | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type Log struct { | ||||
| @@ -22,6 +23,15 @@ type Log struct { | ||||
| 	ChannelId        int    `json:"channel" gorm:"index"` | ||||
| } | ||||
|  | ||||
| type LogStatistic struct { | ||||
| 	Day              string `gorm:"column:day"` | ||||
| 	ModelName        string `gorm:"column:model_name"` | ||||
| 	RequestCount     int    `gorm:"column:request_count"` | ||||
| 	Quota            int    `gorm:"column:quota"` | ||||
| 	PromptTokens     int    `gorm:"column:prompt_tokens"` | ||||
| 	CompletionTokens int    `gorm:"column:completion_tokens"` | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	LogTypeUnknown = iota | ||||
| 	LogTypeTopup | ||||
| @@ -94,7 +104,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | ||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||
| 	} | ||||
| 	if channel != 0 { | ||||
| 		tx = tx.Where("channel = ?", channel) | ||||
| 		tx = tx.Where("channel_id = ?", channel) | ||||
| 	} | ||||
| 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | ||||
| 	return logs, err | ||||
| @@ -134,7 +144,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | ||||
| 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| 	tx := DB.Table("logs").Select(assembleSumSelectStr("quota")) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -151,14 +161,14 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| 	} | ||||
| 	if channel != 0 { | ||||
| 		tx = tx.Where("channel = ?", channel) | ||||
| 		tx = tx.Where("channel_id = ?", channel) | ||||
| 	} | ||||
| 	tx.Where("type = ?", LogTypeConsume).Scan("a) | ||||
| 	return quota | ||||
| } | ||||
|  | ||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||
| 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||
| 	tx := DB.Table("logs").Select(assembleSumSelectStr("prompt_tokens") + " + " + assembleSumSelectStr("completion_tokens")) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -182,3 +192,41 @@ func DeleteOldLog(targetTimestamp int64) (int64, error) { | ||||
| 	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|  | ||||
| func SearchLogsByDayAndModel(user_id, start, end int) (LogStatistics []*LogStatistic, err error) { | ||||
| 	groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day" | ||||
|  | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day" | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Raw(` | ||||
| 		SELECT `+groupSelect+`, | ||||
| 		model_name, count(1) as request_count, | ||||
| 		sum(quota) as quota, | ||||
| 		sum(prompt_tokens) as prompt_tokens, | ||||
| 		sum(completion_tokens) as completion_tokens | ||||
| 		FROM logs | ||||
| 		WHERE type=2 | ||||
| 		AND user_id= ? | ||||
| 		AND created_at BETWEEN ? AND ? | ||||
| 		GROUP BY day, model_name | ||||
| 		ORDER BY day, model_name | ||||
| 	`, user_id, start, end).Scan(&LogStatistics).Error | ||||
|  | ||||
| 	fmt.Println(user_id, start, end) | ||||
|  | ||||
| 	return LogStatistics, err | ||||
| } | ||||
|  | ||||
| func assembleSumSelectStr(selectStr string) string { | ||||
| 	sumSelectStr := "%s(sum(%s),0)" | ||||
| 	nullfunc := "ifnull" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		nullfunc = "coalesce" | ||||
| 	} | ||||
|  | ||||
| 	sumSelectStr = fmt.Sprintf(sumSelectStr, nullfunc, selectStr) | ||||
|  | ||||
| 	return sumSelectStr | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"gorm.io/driver/mysql" | ||||
| 	"gorm.io/driver/postgres" | ||||
| 	"gorm.io/driver/sqlite" | ||||
| @@ -42,6 +43,7 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 		if strings.HasPrefix(dsn, "postgres://") { | ||||
| 			// Use PostgreSQL | ||||
| 			common.SysLog("using PostgreSQL as database") | ||||
| 			common.UsingPostgreSQL = true | ||||
| 			return gorm.Open(postgres.New(postgres.Config{ | ||||
| 				DSN:                  dsn, | ||||
| 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||
| @@ -58,7 +60,8 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 	// Use SQLite | ||||
| 	common.SysLog("SQL_DSN not set, using SQLite as database") | ||||
| 	common.UsingSQLite = true | ||||
| 	return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ | ||||
| 	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) | ||||
| 	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ | ||||
| 		PrepareStmt: true, // precompile SQL | ||||
| 	}) | ||||
| } | ||||
|   | ||||
| @@ -34,6 +34,7 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) | ||||
| 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) | ||||
| 	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) | ||||
| 	common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) | ||||
| 	common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) | ||||
| 	common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) | ||||
| 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | ||||
| @@ -147,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 			common.EmailDomainRestrictionEnabled = boolValue | ||||
| 		case "AutomaticDisableChannelEnabled": | ||||
| 			common.AutomaticDisableChannelEnabled = boolValue | ||||
| 		case "AutomaticEnableChannelEnabled": | ||||
| 			common.AutomaticEnableChannelEnabled = boolValue | ||||
| 		case "ApproximateTokenEnabled": | ||||
| 			common.ApproximateTokenEnabled = boolValue | ||||
| 		case "LogConsumeEnabled": | ||||
|   | ||||
| @@ -3,8 +3,9 @@ package model | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type Redemption struct { | ||||
| @@ -27,7 +28,7 @@ func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) { | ||||
| } | ||||
|  | ||||
| func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { | ||||
| 	err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error | ||||
| 	err = DB.Where("id = ? or name LIKE ?", common.String2Int(keyword), keyword+"%").Find(&redemptions).Error | ||||
| 	return redemptions, err | ||||
| } | ||||
|  | ||||
| @@ -50,8 +51,13 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 	} | ||||
| 	redemption := &Redemption{} | ||||
|  | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error | ||||
| 		if err != nil { | ||||
| 			return errors.New("无效的兑换码") | ||||
| 		} | ||||
|   | ||||
| @@ -3,9 +3,10 @@ package model | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // User if you add sensitive fields, don't forget to clean them in setupLogin function. | ||||
| @@ -42,7 +43,8 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { | ||||
| } | ||||
|  | ||||
| func SearchUsers(keyword string) (users []*User, err error) { | ||||
| 	err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error | ||||
| 	err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", common.String2Int(keyword), keyword+"%", keyword+"%", keyword+"%").Find(&users).Error | ||||
|  | ||||
| 	return users, err | ||||
| } | ||||
|  | ||||
| @@ -266,7 +268,12 @@ func GetUserEmail(id int) (email string, err error) { | ||||
| } | ||||
|  | ||||
| func GetUserGroup(id int) (group string, err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error | ||||
| 	groupCol := "`group`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error | ||||
| 	return group, err | ||||
| } | ||||
|  | ||||
| @@ -309,7 +316,8 @@ func GetRootUserEmail() (email string) { | ||||
|  | ||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) | ||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||
| 		return | ||||
| 	} | ||||
| 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||
| @@ -327,6 +335,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func updateUserUsedQuota(id int, quota int) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||
| 		}, | ||||
| 	).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update user used quota: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func updateUserRequestCount(id int, count int) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update user request count: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUsernameById(id int) (username string) { | ||||
| 	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) | ||||
| 	return username | ||||
|   | ||||
| @@ -6,13 +6,13 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock | ||||
|  | ||||
| const ( | ||||
| 	BatchUpdateTypeUserQuota = iota | ||||
| 	BatchUpdateTypeTokenQuota | ||||
| 	BatchUpdateTypeUsedQuotaAndRequestCount | ||||
| 	BatchUpdateTypeUsedQuota | ||||
| 	BatchUpdateTypeChannelUsedQuota | ||||
| 	BatchUpdateTypeRequestCount | ||||
| 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||
| ) | ||||
|  | ||||
| var batchUpdateStores []map[int]int | ||||
| @@ -51,7 +51,7 @@ func batchUpdate() { | ||||
| 		store := batchUpdateStores[i] | ||||
| 		batchUpdateStores[i] = make(map[int]int) | ||||
| 		batchUpdateLocks[i].Unlock() | ||||
|  | ||||
| 		// TODO: maybe we can combine updates with same key? | ||||
| 		for key, value := range store { | ||||
| 			switch i { | ||||
| 			case BatchUpdateTypeUserQuota: | ||||
| @@ -64,8 +64,10 @@ func batchUpdate() { | ||||
| 				if err != nil { | ||||
| 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||
| 				} | ||||
| 			case BatchUpdateTypeUsedQuotaAndRequestCount: | ||||
| 				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect | ||||
| 			case BatchUpdateTypeUsedQuota: | ||||
| 				updateUserUsedQuota(key, value) | ||||
| 			case BatchUpdateTypeRequestCount: | ||||
| 				updateUserRequestCount(key, value) | ||||
| 			case BatchUpdateTypeChannelUsedQuota: | ||||
| 				updateChannelUsedQuota(key, value) | ||||
| 			} | ||||
|   | ||||
							
								
								
									
										30
									
								
								providers/aigc2d/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								providers/aigc2d/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package aigc2d | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"one-api/providers/base" | ||||
| ) | ||||
|  | ||||
| func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) { | ||||
| 	fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") | ||||
| 	headers := p.GetRequestHeaders() | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	// 发送请求 | ||||
| 	var response base.BalanceResponse | ||||
| 	_, errWithCode := common.SendRequest(req, &response, false) | ||||
| 	if errWithCode != nil { | ||||
| 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||
| 	} | ||||
|  | ||||
| 	channel.UpdateBalance(response.TotalAvailable) | ||||
|  | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
							
								
								
									
										20
									
								
								providers/aigc2d/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								providers/aigc2d/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| package aigc2d | ||||
|  | ||||
| import ( | ||||
| 	"one-api/providers/base" | ||||
| 	"one-api/providers/openai" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type Aigc2dProviderFactory struct{} | ||||
|  | ||||
| func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &Aigc2dProvider{ | ||||
| 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type Aigc2dProvider struct { | ||||
| 	*openai.OpenAIProvider | ||||
| } | ||||
							
								
								
									
										35
									
								
								providers/aiproxy/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								providers/aiproxy/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| package aiproxy | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| ) | ||||
|  | ||||
| func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) { | ||||
| 	fullRequestURL := "https://aiproxy.io/api/report/getUserOverview" | ||||
| 	headers := make(map[string]string) | ||||
| 	headers["Api-Key"] = channel.Key | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	// 发送请求 | ||||
| 	var response AIProxyUserOverviewResponse | ||||
| 	_, errWithCode := common.SendRequest(req, &response, false) | ||||
| 	if errWithCode != nil { | ||||
| 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||
| 	} | ||||
|  | ||||
| 	if !response.Success { | ||||
| 		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) | ||||
| 	} | ||||
|  | ||||
| 	channel.UpdateBalance(response.Data.TotalPoints) | ||||
|  | ||||
| 	return response.Data.TotalPoints, nil | ||||
| } | ||||
							
								
								
									
										20
									
								
								providers/aiproxy/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								providers/aiproxy/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| package aiproxy | ||||
|  | ||||
| import ( | ||||
| 	"one-api/providers/base" | ||||
| 	"one-api/providers/openai" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type AIProxyProviderFactory struct{} | ||||
|  | ||||
| func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &AIProxyProvider{ | ||||
| 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type AIProxyProvider struct { | ||||
| 	*openai.OpenAIProvider | ||||
| } | ||||
							
								
								
									
										10
									
								
								providers/aiproxy/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								providers/aiproxy/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| package aiproxy | ||||
|  | ||||
| type AIProxyUserOverviewResponse struct { | ||||
| 	Success   bool   `json:"success"` | ||||
| 	Message   string `json:"message"` | ||||
| 	ErrorCode int    `json:"error_code"` | ||||
| 	Data      struct { | ||||
| 		TotalPoints float64 `json:"totalPoints"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
							
								
								
									
										41
									
								
								providers/ali/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								providers/ali/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"one-api/providers/base" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| // 定义供应商工厂 | ||||
| type AliProviderFactory struct{} | ||||
|  | ||||
| // 创建 AliProvider | ||||
| // https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation | ||||
| func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &AliProvider{ | ||||
| 		BaseProvider: base.BaseProvider{ | ||||
| 			BaseURL:         "https://dashscope.aliyuncs.com", | ||||
| 			ChatCompletions: "/api/v1/services/aigc/text-generation/generation", | ||||
| 			Embeddings:      "/api/v1/services/embeddings/text-embedding/text-embedding", | ||||
| 			Context:         c, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type AliProvider struct { | ||||
| 	base.BaseProvider | ||||
| } | ||||
|  | ||||
| // 获取请求头 | ||||
| func (p *AliProvider) GetRequestHeaders() (headers map[string]string) { | ||||
| 	headers = make(map[string]string) | ||||
| 	p.CommonRequestHeaders(headers) | ||||
| 	headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) | ||||
| 	if p.Context.GetString("plugin") != "" { | ||||
| 		headers["X-DashScope-Plugin"] = p.Context.GetString("plugin") | ||||
| 	} | ||||
|  | ||||
| 	return headers | ||||
| } | ||||
							
								
								
									
										216
									
								
								providers/ali/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										216
									
								
								providers/ali/chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,216 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/types" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // 阿里云响应处理 | ||||
| func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	if aliResponse.Code != "" { | ||||
| 		errWithCode = &types.OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: types.OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		} | ||||
|  | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	choice := types.ChatCompletionChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: types.ChatCompletionMessage{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: aliResponse.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: aliResponse.Output.FinishReason, | ||||
| 	} | ||||
|  | ||||
| 	OpenAIResponse = types.ChatCompletionResponse{ | ||||
| 		ID:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []types.ChatCompletionChoice{choice}, | ||||
| 		Usage: &types.Usage{ | ||||
| 			PromptTokens:     aliResponse.Usage.InputTokens, | ||||
| 			CompletionTokens: aliResponse.Usage.OutputTokens, | ||||
| 			TotalTokens:      aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // 获取聊天请求体 | ||||
| func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { | ||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		messages = append(messages, AliMessage{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    strings.ToLower(message.Role), | ||||
| 		}) | ||||
| 	} | ||||
| 	return &AliChatRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Input: AliInput{ | ||||
| 			Messages: messages, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 聊天 | ||||
| func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
|  | ||||
| 	requestBody := p.getChatRequestBody(request) | ||||
| 	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) | ||||
| 	headers := p.GetRequestHeaders() | ||||
| 	if request.Stream { | ||||
| 		headers["Accept"] = "text/event-stream" | ||||
| 		headers["X-DashScope-SSE"] = "enable" | ||||
| 	} | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if request.Stream { | ||||
| 		usage, errWithCode = p.sendStreamRequest(req) | ||||
| 		if errWithCode != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if usage == nil { | ||||
| 			usage = &types.Usage{ | ||||
| 				PromptTokens:     0, | ||||
| 				CompletionTokens: 0, | ||||
| 				TotalTokens:      0, | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 	} else { | ||||
| 		aliResponse := &AliChatResponse{} | ||||
| 		errWithCode = p.SendRequest(req, aliResponse, false) | ||||
| 		if errWithCode != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		usage = &types.Usage{ | ||||
| 			PromptTokens:     aliResponse.Usage.InputTokens, | ||||
| 			CompletionTokens: aliResponse.Usage.OutputTokens, | ||||
| 			TotalTokens:      aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // 阿里云响应转OpenAI响应 | ||||
| func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse { | ||||
| 	var choice types.ChatCompletionStreamChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	if aliResponse.Output.FinishReason != "null" { | ||||
| 		finishReason := aliResponse.Output.FinishReason | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
|  | ||||
| 	response := types.ChatCompletionStreamResponse{ | ||||
| 		ID:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []types.ChatCompletionStreamChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| // 发送流请求 | ||||
| func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	defer req.Body.Close() | ||||
|  | ||||
| 	usage = &types.Usage{} | ||||
| 	// 发送请求 | ||||
| 	resp, err := common.HttpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if common.IsFailureStatusCode(resp) { | ||||
| 		return nil, common.HandleErrorResp(resp) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(p.Context) | ||||
| 	lastResponseText := "" | ||||
| 	p.Context.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse AliChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if aliResponse.Usage.OutputTokens != 0 { | ||||
| 				usage.PromptTokens = aliResponse.Usage.InputTokens | ||||
| 				usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			} | ||||
| 			response := p.streamResponseAli2OpenAI(&aliResponse) | ||||
| 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										73
									
								
								providers/ali/embeddings.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								providers/ali/embeddings.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,73 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/types" | ||||
| ) | ||||
|  | ||||
| // 嵌入请求处理 | ||||
| func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) { | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return nil, &types.OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: types.OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	openAIEmbeddingResponse := &types.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Usage:  &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens}, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range aliResponse.Output.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{ | ||||
| 			Object:    `embedding`, | ||||
| 			Index:     item.TextIndex, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	return openAIEmbeddingResponse, nil | ||||
| } | ||||
|  | ||||
| // 获取嵌入请求体 | ||||
| func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest { | ||||
| 	return &AliEmbeddingRequest{ | ||||
| 		Model: "text-embedding-v1", | ||||
| 		Input: struct { | ||||
| 			Texts []string `json:"texts"` | ||||
| 		}{ | ||||
| 			Texts: request.ParseInput(), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
|  | ||||
| 	requestBody := p.getEmbeddingsRequestBody(request) | ||||
| 	fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) | ||||
| 	headers := p.GetRequestHeaders() | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	aliEmbeddingResponse := &AliEmbeddingResponse{} | ||||
| 	errWithCode = p.SendRequest(req, aliEmbeddingResponse, false) | ||||
| 	if errWithCode != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens} | ||||
|  | ||||
| 	return usage, nil | ||||
| } | ||||
							
								
								
									
										70
									
								
								providers/ali/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								providers/ali/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,70 @@ | ||||
| package ali | ||||
|  | ||||
| type AliError struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type AliUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type AliMessage struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| } | ||||
|  | ||||
| type AliInput struct { | ||||
| 	// Prompt  string       `json:"prompt"` | ||||
| 	Messages []AliMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| type AliParameters struct { | ||||
| 	TopP         float64 `json:"top_p,omitempty"` | ||||
| 	TopK         int     `json:"top_k,omitempty"` | ||||
| 	Seed         uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliChatRequest struct { | ||||
| 	Model      string        `json:"model"` | ||||
| 	Input      AliInput      `json:"input"` | ||||
| 	Parameters AliParameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliOutput struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type AliChatResponse struct { | ||||
| 	Output AliOutput `json:"output"` | ||||
| 	Usage  AliUsage  `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| type AliEmbeddingRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Input struct { | ||||
| 		Texts []string `json:"texts"` | ||||
| 	} `json:"input"` | ||||
| 	Parameters *struct { | ||||
| 		TextType string `json:"text_type,omitempty"` | ||||
| 	} `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliEmbedding struct { | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	TextIndex int       `json:"text_index"` | ||||
| } | ||||
|  | ||||
| type AliEmbeddingResponse struct { | ||||
| 	Output struct { | ||||
| 		Embeddings []AliEmbedding `json:"embeddings"` | ||||
| 	} `json:"output"` | ||||
| 	Usage AliUsage `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
							
								
								
									
										30
									
								
								providers/api2d/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								providers/api2d/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package api2d | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"one-api/providers/base" | ||||
| ) | ||||
|  | ||||
| func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) { | ||||
| 	fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") | ||||
| 	headers := p.GetRequestHeaders() | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	// 发送请求 | ||||
| 	var response base.BalanceResponse | ||||
| 	_, errWithCode := common.SendRequest(req, &response, false) | ||||
| 	if errWithCode != nil { | ||||
| 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||
| 	} | ||||
|  | ||||
| 	channel.UpdateBalance(response.TotalAvailable) | ||||
|  | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
							
								
								
									
										21
									
								
								providers/api2d/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								providers/api2d/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| package api2d | ||||
|  | ||||
| import ( | ||||
| 	"one-api/providers/base" | ||||
| 	"one-api/providers/openai" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type Api2dProviderFactory struct{} | ||||
|  | ||||
| // 创建 Api2dProvider | ||||
| func (f Api2dProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &Api2dProvider{ | ||||
| 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type Api2dProvider struct { | ||||
| 	*openai.OpenAIProvider | ||||
| } | ||||
							
								
								
									
										30
									
								
								providers/api2gpt/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								providers/api2gpt/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package api2gpt | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"one-api/providers/base" | ||||
| ) | ||||
|  | ||||
| func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) { | ||||
| 	fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") | ||||
| 	headers := p.GetRequestHeaders() | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	// 发送请求 | ||||
| 	var response base.BalanceResponse | ||||
| 	_, errWithCode := common.SendRequest(req, &response, false) | ||||
| 	if errWithCode != nil { | ||||
| 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||
| 	} | ||||
|  | ||||
| 	channel.UpdateBalance(response.TotalAvailable) | ||||
|  | ||||
| 	return response.TotalRemaining, nil | ||||
| } | ||||
							
								
								
									
										20
									
								
								providers/api2gpt/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								providers/api2gpt/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| package api2gpt | ||||
|  | ||||
| import ( | ||||
| 	"one-api/providers/base" | ||||
| 	"one-api/providers/openai" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type Api2gptProviderFactory struct{} | ||||
|  | ||||
| func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &Api2gptProvider{ | ||||
| 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type Api2gptProvider struct { | ||||
| 	*openai.OpenAIProvider | ||||
| } | ||||
							
								
								
									
										36
									
								
								providers/azure/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								providers/azure/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package azure | ||||
|  | ||||
| import ( | ||||
| 	"one-api/providers/base" | ||||
| 	"one-api/providers/openai" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type AzureProviderFactory struct{} | ||||
|  | ||||
| // 创建 AzureProvider | ||||
| func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &AzureProvider{ | ||||
| 		OpenAIProvider: openai.OpenAIProvider{ | ||||
| 			BaseProvider: base.BaseProvider{ | ||||
| 				BaseURL:             "", | ||||
| 				Completions:         "/completions", | ||||
| 				ChatCompletions:     "/chat/completions", | ||||
| 				Embeddings:          "/embeddings", | ||||
| 				AudioTranscriptions: "/audio/transcriptions", | ||||
| 				AudioTranslations:   "/audio/translations", | ||||
| 				ImagesGenerations:   "/images/generations", | ||||
| 				// ImagesEdit:          "/images/edit", | ||||
| 				// ImagesVariations:    "/images/variations", | ||||
| 				Context: c, | ||||
| 				// AudioSpeech:         "/audio/speech", | ||||
| 			}, | ||||
| 			IsAzure: true, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type AzureProvider struct { | ||||
| 	openai.OpenAIProvider | ||||
| } | ||||
							
								
								
									
										102
									
								
								providers/azure/image_generations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								providers/azure/image_generations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,102 @@ | ||||
| package azure | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/providers/openai" | ||||
| 	"one-api/types" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	if c.Status == "canceled" || c.Status == "failed" { | ||||
| 		errWithCode = &types.OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: types.OpenAIError{ | ||||
| 				Message: c.Error.Message, | ||||
| 				Type:    "one_api_error", | ||||
| 				Code:    c.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	operation_location := resp.Header.Get("operation-location") | ||||
| 	if operation_location == "" { | ||||
| 		return nil, common.ErrorWrapper(errors.New("image url is empty"), "get_images_url_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest("GET", operation_location, common.WithHeader(c.Header)) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "get_images_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	getImageAzureResponse := ImageAzureResponse{} | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		// 休眠 2 秒 | ||||
| 		time.Sleep(2 * time.Second) | ||||
| 		_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false) | ||||
| 		fmt.Println("getImageAzureResponse", getImageAzureResponse) | ||||
| 		if errWithCode != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if getImageAzureResponse.Status == "canceled" || getImageAzureResponse.Status == "failed" { | ||||
| 			return nil, &types.OpenAIErrorWithStatusCode{ | ||||
| 				OpenAIError: types.OpenAIError{ | ||||
| 					Message: c.Error.Message, | ||||
| 					Type:    "get_images_request_failed", | ||||
| 					Code:    c.Error.Code, | ||||
| 				}, | ||||
| 				StatusCode: resp.StatusCode, | ||||
| 			} | ||||
| 		} | ||||
| 		if getImageAzureResponse.Status == "succeeded" { | ||||
| 			return getImageAzureResponse.Result, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil, common.ErrorWrapper(errors.New("get image Timeout"), "get_images_url_failed", http.StatusInternalServerError) | ||||
| } | ||||
|  | ||||
| func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
|  | ||||
| 	requestBody, err := p.GetRequestBody(&request, isModelMapped) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model) | ||||
| 	headers := p.GetRequestHeaders() | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if request.Model == "dall-e-2" { | ||||
| 		imageAzureResponse := &ImageAzureResponse{ | ||||
| 			Header: headers, | ||||
| 		} | ||||
| 		errWithCode = p.SendRequest(req, imageAzureResponse, false) | ||||
| 	} else { | ||||
| 		openAIProviderImageResponseResponse := &openai.OpenAIProviderImageResponseResponse{} | ||||
| 		errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) | ||||
| 	} | ||||
|  | ||||
| 	if errWithCode != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage = &types.Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: 0, | ||||
| 		TotalTokens:      promptTokens, | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										21
									
								
								providers/azure/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								providers/azure/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| package azure | ||||
|  | ||||
| import "one-api/types" | ||||
|  | ||||
| type ImageAzureResponse struct { | ||||
| 	ID      string              `json:"id,omitempty"` | ||||
| 	Created int64               `json:"created,omitempty"` | ||||
| 	Expires int64               `json:"expires,omitempty"` | ||||
| 	Result  types.ImageResponse `json:"result,omitempty"` | ||||
| 	Status  string              `json:"status,omitempty"` | ||||
| 	Error   ImageAzureError     `json:"error,omitempty"` | ||||
| 	Header  map[string]string   `json:"header,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageAzureError struct { | ||||
| 	Code       string   `json:"code,omitempty"` | ||||
| 	Target     string   `json:"target,omitempty"` | ||||
| 	Message    string   `json:"message,omitempty"` | ||||
| 	Details    []string `json:"details,omitempty"` | ||||
| 	InnerError any      `json:"innererror,omitempty"` | ||||
| } | ||||
							
								
								
									
										36
									
								
								providers/azureSpeech/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								providers/azureSpeech/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package azureSpeech | ||||
|  | ||||
| import ( | ||||
| 	"one-api/providers/base" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| // 定义供应商工厂 | ||||
| type AzureSpeechProviderFactory struct{} | ||||
|  | ||||
| // 创建 AliProvider | ||||
| func (f AzureSpeechProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &AzureSpeechProvider{ | ||||
| 		BaseProvider: base.BaseProvider{ | ||||
| 			BaseURL:     "", | ||||
| 			AudioSpeech: "/cognitiveservices/v1", | ||||
| 			Context:     c, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type AzureSpeechProvider struct { | ||||
| 	base.BaseProvider | ||||
| } | ||||
|  | ||||
| // 获取请求头 | ||||
| func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) { | ||||
| 	headers = make(map[string]string) | ||||
| 	headers["Ocp-Apim-Subscription-Key"] = p.Context.GetString("api_key") | ||||
| 	headers["Content-Type"] = "application/ssml+xml" | ||||
| 	headers["User-Agent"] = "OneAPI" | ||||
| 	// headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3" | ||||
|  | ||||
| 	return headers | ||||
| } | ||||
							
								
								
									
										88
									
								
								providers/azureSpeech/speech.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								providers/azureSpeech/speech.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,88 @@ | ||||
| package azureSpeech | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/types" | ||||
| ) | ||||
|  | ||||
| var outputFormatMap = map[string]string{ | ||||
| 	"mp3":  "audio-16khz-128kbitrate-mono-mp3", | ||||
| 	"opus": "audio-16khz-128kbitrate-mono-opus", | ||||
| 	"aac":  "audio-24khz-160kbitrate-mono-mp3", | ||||
| 	"flac": "audio-48khz-192kbitrate-mono-mp3", | ||||
| } | ||||
|  | ||||
| func CreateSSML(text string, name string, role string) string { | ||||
| 	ssmlTemplate := `<speak version='1.0' xml:lang='en-US'> | ||||
|         <voice xml:lang='en-US' %s name='%s'> | ||||
|             %s | ||||
|         </voice> | ||||
|     </speak>` | ||||
|  | ||||
| 	roleAttribute := "" | ||||
| 	if role != "" { | ||||
| 		roleAttribute = fmt.Sprintf("role='%s'", role) | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf(ssmlTemplate, roleAttribute, name, text) | ||||
| } | ||||
|  | ||||
| func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest) *bytes.Buffer { | ||||
| 	voiceMap := map[string][]string{ | ||||
| 		"alloy":   {"zh-CN-YunxiNeural"}, | ||||
| 		"echo":    {"zh-CN-YunyangNeural"}, | ||||
| 		"fable":   {"zh-CN-YunxiNeural", "Boy"}, | ||||
| 		"onyx":    {"zh-CN-YunyeNeural"}, | ||||
| 		"nova":    {"zh-CN-XiaochenNeural"}, | ||||
| 		"shimmer": {"zh-CN-XiaohanNeural"}, | ||||
| 	} | ||||
|  | ||||
| 	voice := "" | ||||
| 	role := "" | ||||
| 	if voiceMap[request.Voice] != nil { | ||||
| 		voice = voiceMap[request.Voice][0] | ||||
| 		if len(voiceMap[request.Voice]) > 1 { | ||||
| 			role = voiceMap[request.Voice][1] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	ssml := CreateSSML(request.Input, voice, role) | ||||
|  | ||||
| 	return bytes.NewBufferString(ssml) | ||||
|  | ||||
| } | ||||
|  | ||||
| func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
|  | ||||
| 	fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model) | ||||
| 	headers := p.GetRequestHeaders() | ||||
| 	responseFormatr := outputFormatMap[request.ResponseFormat] | ||||
| 	if responseFormatr == "" { | ||||
| 		responseFormatr = outputFormatMap["mp3"] | ||||
| 	} | ||||
| 	headers["X-Microsoft-OutputFormat"] = responseFormatr | ||||
|  | ||||
| 	requestBody := p.getRequestBody(request) | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	errWithCode = p.SendRequestRaw(req) | ||||
| 	if errWithCode != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	usage = &types.Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: 0, | ||||
| 		TotalTokens:      promptTokens, | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										129
									
								
								providers/baidu/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								providers/baidu/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,129 @@ | ||||
| package baidu | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"one-api/common" | ||||
| 	"one-api/providers/base" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| // 定义供应商工厂 | ||||
| type BaiduProviderFactory struct{} | ||||
|  | ||||
| // 创建 BaiduProvider | ||||
|  | ||||
| func (f BaiduProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||
| 	return &BaiduProvider{ | ||||
| 		BaseProvider: base.BaseProvider{ | ||||
| 			BaseURL:         "https://aip.baidubce.com", | ||||
| 			ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat", | ||||
| 			Embeddings:      "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings", | ||||
| 			Context:         c, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| var baiduTokenStore sync.Map | ||||
|  | ||||
| type BaiduProvider struct { | ||||
| 	base.BaseProvider | ||||
| } | ||||
|  | ||||
| // 获取完整请求 URL | ||||
| func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string { | ||||
| 	var modelNameMap = map[string]string{ | ||||
| 		"ERNIE-Bot":       "completions", | ||||
| 		"ERNIE-Bot-turbo": "eb-instant", | ||||
| 		"ERNIE-Bot-4":     "completions_pro", | ||||
| 		"BLOOMZ-7B":       "bloomz_7b1", | ||||
| 		"Embedding-V1":    "embedding-v1", | ||||
| 	} | ||||
|  | ||||
| 	baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") | ||||
| 	apiKey, err := p.getBaiduAccessToken() | ||||
| 	if err != nil { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("%s%s/%s?access_token=%s", baseURL, requestURL, modelNameMap[modelName], apiKey) | ||||
| } | ||||
|  | ||||
| // 获取请求头 | ||||
| func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) { | ||||
| 	headers = make(map[string]string) | ||||
| 	p.CommonRequestHeaders(headers) | ||||
|  | ||||
| 	return headers | ||||
| } | ||||
|  | ||||
| func (p *BaiduProvider) getBaiduAccessToken() (string, error) { | ||||
| 	apiKey := p.Context.GetString("api_key") | ||||
| 	if val, ok := baiduTokenStore.Load(apiKey); ok { | ||||
| 		var accessToken BaiduAccessToken | ||||
| 		if accessToken, ok = val.(BaiduAccessToken); ok { | ||||
| 			// soon this will expire | ||||
| 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { | ||||
| 				go func() { | ||||
| 					_, _ = p.getBaiduAccessTokenHelper(apiKey) | ||||
| 				}() | ||||
| 			} | ||||
| 			return accessToken.AccessToken, nil | ||||
| 		} | ||||
| 	} | ||||
| 	accessToken, err := p.getBaiduAccessTokenHelper(apiKey) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if accessToken == nil { | ||||
| 		return "", errors.New("getBaiduAccessToken return a nil token") | ||||
| 	} | ||||
| 	return (*accessToken).AccessToken, nil | ||||
| } | ||||
|  | ||||
| func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { | ||||
| 	parts := strings.Split(apiKey, "|") | ||||
| 	if len(parts) != 2 { | ||||
| 		return nil, errors.New("invalid baidu apikey") | ||||
| 	} | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1]) | ||||
|  | ||||
| 	var headers = map[string]string{ | ||||
| 		"Content-Type": "application/json", | ||||
| 		"Accept":       "application/json", | ||||
| 	} | ||||
|  | ||||
| 	req, err := client.NewRequest("POST", url, common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	resp, err := common.HttpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	var accessToken BaiduAccessToken | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&accessToken) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if accessToken.Error != "" { | ||||
| 		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) | ||||
| 	} | ||||
| 	if accessToken.AccessToken == "" { | ||||
| 		return nil, errors.New("getBaiduAccessTokenHelper get empty access token") | ||||
| 	} | ||||
| 	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) | ||||
| 	baiduTokenStore.Store(apiKey, accessToken) | ||||
| 	return &accessToken, nil | ||||
| } | ||||
							
								
								
									
										198
									
								
								providers/baidu/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										198
									
								
								providers/baidu/chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,198 @@ | ||||
| package baidu | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/providers/base" | ||||
| 	"one-api/types" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return nil, &types.OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: types.OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	choice := types.ChatCompletionChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: types.ChatCompletionMessage{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: baiduResponse.Result, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
|  | ||||
| 	OpenAIResponse = types.ChatCompletionResponse{ | ||||
| 		ID:      baiduResponse.Id, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: baiduResponse.Created, | ||||
| 		Choices: []types.ChatCompletionChoice{choice}, | ||||
| 		Usage:   baiduResponse.Usage, | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest { | ||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &BaiduChatRequest{ | ||||
| 		Messages: messages, | ||||
| 		Stream:   request.Stream, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	requestBody := p.getChatRequestBody(request) | ||||
| 	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) | ||||
| 	if fullRequestURL == "" { | ||||
| 		return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	headers := p.GetRequestHeaders() | ||||
| 	if request.Stream { | ||||
| 		headers["Accept"] = "text/event-stream" | ||||
| 	} | ||||
|  | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if request.Stream { | ||||
| 		usage, errWithCode = p.sendStreamRequest(req) | ||||
| 		if errWithCode != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 	} else { | ||||
| 		baiduChatRequest := &BaiduChatResponse{} | ||||
| 		errWithCode = p.SendRequest(req, baiduChatRequest, false) | ||||
| 		if errWithCode != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		usage = baiduChatRequest.Usage | ||||
| 	} | ||||
| 	return | ||||
|  | ||||
| } | ||||
|  | ||||
| func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse { | ||||
| 	var choice types.ChatCompletionStreamChoice | ||||
| 	choice.Delta.Content = baiduResponse.Result | ||||
| 	if baiduResponse.IsEnd { | ||||
| 		choice.FinishReason = &base.StopFinishReason | ||||
| 	} | ||||
|  | ||||
| 	response := types.ChatCompletionStreamResponse{ | ||||
| 		ID:      baiduResponse.Id, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: baiduResponse.Created, | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []types.ChatCompletionStreamChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	defer req.Body.Close() | ||||
|  | ||||
| 	usage = &types.Usage{} | ||||
| 	// 发送请求 | ||||
| 	resp, err := common.HttpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	if common.IsFailureStatusCode(resp) { | ||||
| 		return nil, common.HandleErrorResp(resp) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[6:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(p.Context) | ||||
| 	p.Context.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var baiduResponse BaiduChatStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if baiduResponse.Usage.TotalTokens != 0 { | ||||
| 				usage.TotalTokens = baiduResponse.Usage.TotalTokens | ||||
| 				usage.PromptTokens = baiduResponse.Usage.PromptTokens | ||||
| 				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens | ||||
| 			} | ||||
| 			response := p.streamResponseBaidu2OpenAI(&baiduResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	return usage, nil | ||||
| } | ||||
							
								
								
									
										69
									
								
								providers/baidu/embeddings.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								providers/baidu/embeddings.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| package baidu | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/types" | ||||
| ) | ||||
|  | ||||
| func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest { | ||||
| 	return &BaiduEmbeddingRequest{ | ||||
| 		Input: request.ParseInput(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return nil, &types.OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: types.OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	openAIEmbeddingResponse := &types.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]types.Embedding, 0, len(baiduResponse.Data)), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Usage:  &baiduResponse.Usage, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range baiduResponse.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	return openAIEmbeddingResponse, nil | ||||
| } | ||||
|  | ||||
| func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||
|  | ||||
| 	requestBody := p.getEmbeddingsRequestBody(request) | ||||
| 	fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) | ||||
| 	if fullRequestURL == "" { | ||||
| 		return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	headers := p.GetRequestHeaders() | ||||
| 	client := common.NewClient() | ||||
| 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||
| 	if err != nil { | ||||
| 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	baiduEmbeddingResponse := &BaiduEmbeddingResponse{} | ||||
| 	errWithCode = p.SendRequest(req, baiduEmbeddingResponse, false) | ||||
| 	if errWithCode != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	usage = &baiduEmbeddingResponse.Usage | ||||
|  | ||||
| 	return usage, nil | ||||
| } | ||||
							
								
								
									
										66
									
								
								providers/baidu/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								providers/baidu/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| package baidu | ||||
|  | ||||
| import ( | ||||
| 	"one-api/types" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type BaiduAccessToken struct { | ||||
| 	AccessToken      string    `json:"access_token"` | ||||
| 	Error            string    `json:"error,omitempty"` | ||||
| 	ErrorDescription string    `json:"error_description,omitempty"` | ||||
| 	ExpiresIn        int64     `json:"expires_in,omitempty"` | ||||
| 	ExpiresAt        time.Time `json:"-"` | ||||
| } | ||||
|  | ||||
| type BaiduMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type BaiduChatRequest struct { | ||||
| 	Messages []BaiduMessage `json:"messages"` | ||||
| 	Stream   bool           `json:"stream"` | ||||
| 	UserId   string         `json:"user_id,omitempty"` | ||||
| } | ||||
|  | ||||
| type BaiduChatResponse struct { | ||||
| 	Id               string       `json:"id"` | ||||
| 	Object           string       `json:"object"` | ||||
| 	Created          int64        `json:"created"` | ||||
| 	Result           string       `json:"result"` | ||||
| 	IsTruncated      bool         `json:"is_truncated"` | ||||
| 	NeedClearHistory bool         `json:"need_clear_history"` | ||||
| 	Usage            *types.Usage `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingRequest struct { | ||||
| 	Input []string `json:"input"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingData struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	Index     int       `json:"index"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingResponse struct { | ||||
| 	Id      string               `json:"id"` | ||||
| 	Object  string               `json:"object"` | ||||
| 	Created int64                `json:"created"` | ||||
| 	Data    []BaiduEmbeddingData `json:"data"` | ||||
| 	Usage   types.Usage          `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| type BaiduChatStreamResponse struct { | ||||
| 	BaiduChatResponse | ||||
| 	SentenceId int  `json:"sentence_id"` | ||||
| 	IsEnd      bool `json:"is_end"` | ||||
| } | ||||
|  | ||||
| type BaiduError struct { | ||||
| 	ErrorCode int    `json:"error_code"` | ||||
| 	ErrorMsg  string `json:"error_msg"` | ||||
| } | ||||
							
								
								
									
										156
									
								
								providers/base/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										156
									
								
								providers/base/common.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,156 @@ | ||||
| package base | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/types" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| var StopFinishReason = "stop" | ||||
|  | ||||
| type BaseProvider struct { | ||||
| 	BaseURL             string | ||||
| 	Completions         string | ||||
| 	ChatCompletions     string | ||||
| 	Embeddings          string | ||||
| 	AudioSpeech         string | ||||
| 	Moderation          string | ||||
| 	AudioTranscriptions string | ||||
| 	AudioTranslations   string | ||||
| 	ImagesGenerations   string | ||||
| 	ImagesEdit          string | ||||
| 	ImagesVariations    string | ||||
| 	Proxy               string | ||||
| 	Context             *gin.Context | ||||
| } | ||||
|  | ||||
| // 获取基础URL | ||||
| func (p *BaseProvider) GetBaseURL() string { | ||||
| 	if p.Context.GetString("base_url") != "" { | ||||
| 		return p.Context.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	return p.BaseURL | ||||
| } | ||||
|  | ||||
| // 获取完整请求URL | ||||
| func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string { | ||||
| 	baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") | ||||
|  | ||||
| 	return fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| } | ||||
|  | ||||
| // 获取请求头 | ||||
| func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) { | ||||
| 	headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type") | ||||
| 	headers["Accept"] = p.Context.Request.Header.Get("Accept") | ||||
| 	if headers["Content-Type"] == "" { | ||||
| 		headers["Content-Type"] = "application/json" | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 发送请求 | ||||
| func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	defer req.Body.Close() | ||||
|  | ||||
| 	resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true) | ||||
| 	if openAIErrorWithStatusCode != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp) | ||||
| 	if openAIErrorWithStatusCode != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if rawOutput { | ||||
| 		for k, v := range resp.Header { | ||||
| 			p.Context.Writer.Header().Set(k, v[0]) | ||||
| 		} | ||||
|  | ||||
| 		p.Context.Writer.WriteHeader(resp.StatusCode) | ||||
| 		_, err := io.Copy(p.Context.Writer, resp.Body) | ||||
| 		if err != nil { | ||||
| 			return common.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 	} else { | ||||
| 		jsonResponse, err := json.Marshal(openAIResponse) | ||||
| 		if err != nil { | ||||
| 			return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		p.Context.Writer.Header().Set("Content-Type", "application/json") | ||||
| 		p.Context.Writer.WriteHeader(resp.StatusCode) | ||||
| 		_, err = p.Context.Writer.Write(jsonResponse) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { | ||||
| 	defer req.Body.Close() | ||||
|  | ||||
| 	// 发送请求 | ||||
| 	resp, err := common.HttpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	// 处理响应 | ||||
| 	if common.IsFailureStatusCode(resp) { | ||||
| 		return common.HandleErrorResp(resp) | ||||
| 	} | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		p.Context.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
|  | ||||
| 	p.Context.Writer.WriteHeader(resp.StatusCode) | ||||
|  | ||||
| 	_, err = io.Copy(p.Context.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *BaseProvider) SupportAPI(relayMode int) bool { | ||||
| 	switch relayMode { | ||||
| 	case common.RelayModeChatCompletions: | ||||
| 		return p.ChatCompletions != "" | ||||
| 	case common.RelayModeCompletions: | ||||
| 		return p.Completions != "" | ||||
| 	case common.RelayModeEmbeddings: | ||||
| 		return p.Embeddings != "" | ||||
| 	case common.RelayModeAudioSpeech: | ||||
| 		return p.AudioSpeech != "" | ||||
| 	case common.RelayModeAudioTranscription: | ||||
| 		return p.AudioTranscriptions != "" | ||||
| 	case common.RelayModeAudioTranslation: | ||||
| 		return p.AudioTranslations != "" | ||||
| 	case common.RelayModeModerations: | ||||
| 		return p.Moderation != "" | ||||
| 	case common.RelayModeImagesGenerations: | ||||
| 		return p.ImagesGenerations != "" | ||||
| 	case common.RelayModeImagesEdits: | ||||
| 		return p.ImagesEdit != "" | ||||
| 	case common.RelayModeImagesVariations: | ||||
| 		return p.ImagesVariations != "" | ||||
| 	default: | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										84
									
								
								providers/base/interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								providers/base/interface.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | ||||
| package base | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"one-api/model" | ||||
| 	"one-api/types" | ||||
| ) | ||||
|  | ||||
| // 基础接口 | ||||
| type ProviderInterface interface { | ||||
| 	GetBaseURL() string | ||||
| 	GetFullRequestURL(requestURL string, modelName string) string | ||||
| 	GetRequestHeaders() (headers map[string]string) | ||||
| 	SupportAPI(relayMode int) bool | ||||
| } | ||||
|  | ||||
| // 完成接口 | ||||
| type CompletionInterface interface { | ||||
| 	ProviderInterface | ||||
| 	CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 聊天接口 | ||||
| type ChatInterface interface { | ||||
| 	ProviderInterface | ||||
| 	ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 嵌入接口 | ||||
| type EmbeddingsInterface interface { | ||||
| 	ProviderInterface | ||||
| 	EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 审查接口 | ||||
| type ModerationInterface interface { | ||||
| 	ProviderInterface | ||||
| 	ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 文字转语音接口 | ||||
| type SpeechInterface interface { | ||||
| 	ProviderInterface | ||||
| 	SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 语音转文字接口 | ||||
| type TranscriptionsInterface interface { | ||||
| 	ProviderInterface | ||||
| 	TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 语音翻译接口 | ||||
| type TranslationInterface interface { | ||||
| 	ProviderInterface | ||||
| 	TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 图片生成接口 | ||||
| type ImageGenerationsInterface interface { | ||||
| 	ProviderInterface | ||||
| 	ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 图片编辑接口 | ||||
| type ImageEditsInterface interface { | ||||
| 	ProviderInterface | ||||
| 	ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| type ImageVariationsInterface interface { | ||||
| 	ProviderInterface | ||||
| 	ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| // 余额接口 | ||||
| type BalanceInterface interface { | ||||
| 	Balance(channel *model.Channel) (float64, error) | ||||
| } | ||||
|  | ||||
| type ProviderResponseHandler interface { | ||||
| 	// 响应处理函数 | ||||
| 	ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) | ||||
| } | ||||
							
								
								
									
										9
									
								
								providers/base/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								providers/base/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| package base | ||||
|  | ||||
| type BalanceResponse struct { | ||||
| 	Object         string  `json:"object"` | ||||
| 	TotalGranted   float64 `json:"total_granted"` | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| 	TotalRemaining float64 `json:"total_remaining"` | ||||
| 	TotalAvailable float64 `json:"total_available"` | ||||
| } | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user