mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-26 11:23:43 +08:00 
			
		
		
		
	Compare commits
	
		
			58 Commits
		
	
	
		
			v0.6.2-alp
			...
			v0.5.10-3
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | c5aa59e297 | ||
|  | 211a862d54 | ||
|  | c4c89e8e1b | ||
|  | 72983ac734 | ||
|  | 4d43dce64b | ||
|  | 0fa94d3c94 | ||
|  | 002dba5a75 | ||
|  | fb24d024a7 | ||
|  | eeb867da10 | ||
|  | 47b72b850f | ||
|  | be613883a1 | ||
|  | f823581235 | ||
|  | 89e6b9fe33 | ||
|  | 5a8fef00e5 | ||
|  | fe72f85554 | ||
|  | 3ac0b256e3 | ||
|  | b0fefd6dc5 | ||
|  | 43d8bedbb4 | ||
|  | c60f755715 | ||
|  | a4138aec1a | ||
|  | ffa4e491ea | ||
|  | 365744a040 | ||
|  | 8bcaf182bc | ||
|  | 045e2fa139 | ||
|  | a884c4b0bf | ||
|  | c97c8a0f65 | ||
|  | 58fc40a744 | ||
|  | da87fca2a2 | ||
|  | 5e08cc8719 | ||
|  | d8b13b2c07 | ||
|  | be364ae09b | ||
|  | 2114bc1982 | ||
|  | 0f038d715d | ||
|  | 9dd92bbddd | ||
|  | 5b70ee3407 | ||
|  | 17027fb61e | ||
|  | a013b1a166 | ||
|  | 7c6dee7390 | ||
|  | 96dc7614e6 | ||
|  | 1c7c2d40bb | ||
|  | 455269c145 | ||
|  | 544f20cc73 | ||
|  | 902c2faa2c | ||
|  | 53da7134b2 | ||
|  | 1fa1c66f13 | ||
|  | 341c21e4cb | ||
|  | fe56aa1a46 | ||
|  | f0e2ba0318 | ||
|  | 43e7b465cb | ||
|  | 4f245bf738 | ||
|  | 56b3c939bf | ||
|  | 257135f676 | ||
|  | 84784ffccc | ||
|  | ef18eb9f93 | ||
|  | 12499aaf69 | ||
|  | 1e17944e4a | ||
|  | 28c29283c5 | ||
|  | cb3e9b8277 | 
							
								
								
									
										46
									
								
								.air.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								.air.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | root = "." | ||||||
|  | testdata_dir = "testdata" | ||||||
|  | tmp_dir = "tmp" | ||||||
|  |  | ||||||
|  | [build] | ||||||
|  | args_bin = [] | ||||||
|  | bin = "./tmp/main" | ||||||
|  | cmd = "go build -o ./tmp/main ." | ||||||
|  | delay = 1000 | ||||||
|  | exclude_dir = ["assets", "tmp", "vendor", "testdata", "web"] | ||||||
|  | exclude_file = [] | ||||||
|  | exclude_regex = ["_test.go"] | ||||||
|  | exclude_unchanged = false | ||||||
|  | follow_symlink = false | ||||||
|  | full_bin = "" | ||||||
|  | include_dir = [] | ||||||
|  | include_ext = ["go", "tpl", "tmpl", "html"] | ||||||
|  | include_file = [] | ||||||
|  | kill_delay = "0s" | ||||||
|  | log = "build-errors.log" | ||||||
|  | poll = false | ||||||
|  | poll_interval = 0 | ||||||
|  | post_cmd = [] | ||||||
|  | pre_cmd = [] | ||||||
|  | rerun = false | ||||||
|  | rerun_delay = 500 | ||||||
|  | send_interrupt = false | ||||||
|  | stop_on_error = false | ||||||
|  |  | ||||||
|  | [color] | ||||||
|  | app = "" | ||||||
|  | build = "yellow" | ||||||
|  | main = "magenta" | ||||||
|  | runner = "green" | ||||||
|  | watcher = "cyan" | ||||||
|  |  | ||||||
|  | [log] | ||||||
|  | main_only = false | ||||||
|  | time = false | ||||||
|  |  | ||||||
|  | [misc] | ||||||
|  | clean_on_exit = false | ||||||
|  |  | ||||||
|  | [screen] | ||||||
|  | clear_on_rebuild = false | ||||||
|  | keep_scroll = true | ||||||
							
								
								
									
										49
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										49
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,49 +0,0 @@ | |||||||
| name: Publish Docker image (amd64, English) |  | ||||||
|  |  | ||||||
| on: |  | ||||||
|   push: |  | ||||||
|     tags: |  | ||||||
|       - '*' |  | ||||||
|   workflow_dispatch: |  | ||||||
|     inputs: |  | ||||||
|       name: |  | ||||||
|         description: 'reason' |  | ||||||
|         required: false |  | ||||||
| jobs: |  | ||||||
|   push_to_registries: |  | ||||||
|     name: Push Docker image to multiple registries |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     permissions: |  | ||||||
|       packages: write |  | ||||||
|       contents: read |  | ||||||
|     steps: |  | ||||||
|       - name: Check out the repo |  | ||||||
|         uses: actions/checkout@v3 |  | ||||||
|  |  | ||||||
|       - name: Save version info |  | ||||||
|         run: | |  | ||||||
|           git describe --tags > VERSION  |  | ||||||
|  |  | ||||||
|       - name: Translate |  | ||||||
|         run: | |  | ||||||
|           python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json |  | ||||||
|       - name: Log in to Docker Hub |  | ||||||
|         uses: docker/login-action@v2 |  | ||||||
|         with: |  | ||||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} |  | ||||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} |  | ||||||
|  |  | ||||||
|       - name: Extract metadata (tags, labels) for Docker |  | ||||||
|         id: meta |  | ||||||
|         uses: docker/metadata-action@v4 |  | ||||||
|         with: |  | ||||||
|           images: | |  | ||||||
|             justsong/one-api-en |  | ||||||
|  |  | ||||||
|       - name: Build and push Docker images |  | ||||||
|         uses: docker/build-push-action@v3 |  | ||||||
|         with: |  | ||||||
|           context: . |  | ||||||
|           push: true |  | ||||||
|           tags: ${{ steps.meta.outputs.tags }} |  | ||||||
|           labels: ${{ steps.meta.outputs.labels }} |  | ||||||
							
								
								
									
										54
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										54
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,54 +0,0 @@ | |||||||
| name: Publish Docker image (amd64) |  | ||||||
|  |  | ||||||
| on: |  | ||||||
|   push: |  | ||||||
|     tags: |  | ||||||
|       - '*' |  | ||||||
|   workflow_dispatch: |  | ||||||
|     inputs: |  | ||||||
|       name: |  | ||||||
|         description: 'reason' |  | ||||||
|         required: false |  | ||||||
| jobs: |  | ||||||
|   push_to_registries: |  | ||||||
|     name: Push Docker image to multiple registries |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     permissions: |  | ||||||
|       packages: write |  | ||||||
|       contents: read |  | ||||||
|     steps: |  | ||||||
|       - name: Check out the repo |  | ||||||
|         uses: actions/checkout@v3 |  | ||||||
|  |  | ||||||
|       - name: Save version info |  | ||||||
|         run: | |  | ||||||
|           git describe --tags > VERSION  |  | ||||||
|  |  | ||||||
|       - name: Log in to Docker Hub |  | ||||||
|         uses: docker/login-action@v2 |  | ||||||
|         with: |  | ||||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} |  | ||||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} |  | ||||||
|  |  | ||||||
|       - name: Log in to the Container registry |  | ||||||
|         uses: docker/login-action@v2 |  | ||||||
|         with: |  | ||||||
|           registry: ghcr.io |  | ||||||
|           username: ${{ github.actor }} |  | ||||||
|           password: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|  |  | ||||||
|       - name: Extract metadata (tags, labels) for Docker |  | ||||||
|         id: meta |  | ||||||
|         uses: docker/metadata-action@v4 |  | ||||||
|         with: |  | ||||||
|           images: | |  | ||||||
|             justsong/one-api |  | ||||||
|             ghcr.io/${{ github.repository }} |  | ||||||
|  |  | ||||||
|       - name: Build and push Docker images |  | ||||||
|         uses: docker/build-push-action@v3 |  | ||||||
|         with: |  | ||||||
|           context: . |  | ||||||
|           push: true |  | ||||||
|           tags: ${{ steps.meta.outputs.tags }} |  | ||||||
|           labels: ${{ steps.meta.outputs.labels }} |  | ||||||
							
								
								
									
										62
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,62 +0,0 @@ | |||||||
| name: Publish Docker image (arm64) |  | ||||||
|  |  | ||||||
| on: |  | ||||||
|   push: |  | ||||||
|     tags: |  | ||||||
|       - '*' |  | ||||||
|       - '!*-alpha*' |  | ||||||
|   workflow_dispatch: |  | ||||||
|     inputs: |  | ||||||
|       name: |  | ||||||
|         description: 'reason' |  | ||||||
|         required: false |  | ||||||
| jobs: |  | ||||||
|   push_to_registries: |  | ||||||
|     name: Push Docker image to multiple registries |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     permissions: |  | ||||||
|       packages: write |  | ||||||
|       contents: read |  | ||||||
|     steps: |  | ||||||
|       - name: Check out the repo |  | ||||||
|         uses: actions/checkout@v3 |  | ||||||
|  |  | ||||||
|       - name: Save version info |  | ||||||
|         run: | |  | ||||||
|           git describe --tags > VERSION  |  | ||||||
|  |  | ||||||
|       - name: Set up QEMU |  | ||||||
|         uses: docker/setup-qemu-action@v2 |  | ||||||
|  |  | ||||||
|       - name: Set up Docker Buildx |  | ||||||
|         uses: docker/setup-buildx-action@v2 |  | ||||||
|  |  | ||||||
|       - name: Log in to Docker Hub |  | ||||||
|         uses: docker/login-action@v2 |  | ||||||
|         with: |  | ||||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} |  | ||||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} |  | ||||||
|  |  | ||||||
|       - name: Log in to the Container registry |  | ||||||
|         uses: docker/login-action@v2 |  | ||||||
|         with: |  | ||||||
|           registry: ghcr.io |  | ||||||
|           username: ${{ github.actor }} |  | ||||||
|           password: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|  |  | ||||||
|       - name: Extract metadata (tags, labels) for Docker |  | ||||||
|         id: meta |  | ||||||
|         uses: docker/metadata-action@v4 |  | ||||||
|         with: |  | ||||||
|           images: | |  | ||||||
|             justsong/one-api |  | ||||||
|             ghcr.io/${{ github.repository }} |  | ||||||
|  |  | ||||||
|       - name: Build and push Docker images |  | ||||||
|         uses: docker/build-push-action@v3 |  | ||||||
|         with: |  | ||||||
|           context: . |  | ||||||
|           platforms: linux/amd64,linux/arm64 |  | ||||||
|           push: true |  | ||||||
|           tags: ${{ steps.meta.outputs.tags }} |  | ||||||
|           labels: ${{ steps.meta.outputs.labels }} |  | ||||||
							
								
								
									
										62
									
								
								.github/workflows/docker-image.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/docker-image.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | |||||||
|  | name: one-api docker image | ||||||
|  |  | ||||||
|  | on: | ||||||
|  |   push: | ||||||
|  |     branches: | ||||||
|  |       - main | ||||||
|  |     tags: | ||||||
|  |       - "v*" | ||||||
|  |  | ||||||
|  | env: | ||||||
|  |   # github.repository as <account>/<repo> | ||||||
|  |   IMAGE_NAME: martialbe/one-api | ||||||
|  |  | ||||||
|  | jobs: | ||||||
|  |   build-and-push: | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     permissions: | ||||||
|  |       packages: write | ||||||
|  |       contents: read | ||||||
|  |     steps: | ||||||
|  |       - name: Check out the repo | ||||||
|  |         uses: actions/checkout@v3 | ||||||
|  |         with: | ||||||
|  |           fetch-depth: 0 | ||||||
|  |  | ||||||
|  |       - name: Set up QEMU | ||||||
|  |         uses: docker/setup-qemu-action@v2 | ||||||
|  |  | ||||||
|  |       - name: Set up Docker Buildx | ||||||
|  |         uses: docker/setup-buildx-action@v2 | ||||||
|  |  | ||||||
|  |       - name: Login to GHCR | ||||||
|  |         uses: docker/login-action@v2 | ||||||
|  |         with: | ||||||
|  |           registry: ghcr.io | ||||||
|  |           username: ${{ github.repository_owner }} | ||||||
|  |           password: ${{ secrets.GT_Token }} | ||||||
|  |  | ||||||
|  |       - name: Docker meta | ||||||
|  |         id: meta | ||||||
|  |         uses: docker/metadata-action@v4 | ||||||
|  |         with: | ||||||
|  |           # list of Docker images to use as base name for tags | ||||||
|  |           images: ghcr.io/${{ env.IMAGE_NAME }} | ||||||
|  |           # generate Docker tags based on the following events/attributes | ||||||
|  |           tags: | | ||||||
|  |             type=raw,value=dev,enable=${{ github.ref == 'refs/heads/main' }} | ||||||
|  |             type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }} | ||||||
|  |             type=pep440,pattern={{raw}},enable=${{ startsWith(github.ref, 'refs/tags/') }} | ||||||
|  |  | ||||||
|  |       - name: Build and push | ||||||
|  |         uses: docker/build-push-action@v4 | ||||||
|  |         with: | ||||||
|  |           context: . | ||||||
|  |           platforms: linux/amd64 | ||||||
|  |           build-args: | | ||||||
|  |             COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }} | ||||||
|  |           push: true | ||||||
|  |           tags: ${{ steps.meta.outputs.tags }} | ||||||
|  |           labels: ${{ steps.meta.outputs.labels }} | ||||||
|  |           cache-from: type=gha | ||||||
|  |           cache-to: type=gha,mode=max | ||||||
							
								
								
									
										19
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										19
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,13 +5,8 @@ permissions: | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - "*" | ||||||
|       - '!*-alpha*' |       - "!*-alpha*" | ||||||
|   workflow_dispatch: |  | ||||||
|     inputs: |  | ||||||
|       name: |  | ||||||
|         description: 'reason' |  | ||||||
|         required: false |  | ||||||
| jobs: | jobs: | ||||||
|   release: |   release: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
| @@ -28,17 +23,17 @@ jobs: | |||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
|           cd web |           cd web | ||||||
|           git describe --tags > VERSION |           npm install | ||||||
|           REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh |           REACT_APP_VERSION=$(git describe --tags) npm run build | ||||||
|           cd .. |           cd .. | ||||||
|       - name: Set up Go |       - name: Set up Go | ||||||
|         uses: actions/setup-go@v3 |         uses: actions/setup-go@v3 | ||||||
|         with: |         with: | ||||||
|           go-version: '>=1.18.0' |           go-version: ">=1.18.0" | ||||||
|       - name: Build Backend (amd64) |       - name: Build Backend (amd64) | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api |           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
|       - name: Build Backend (arm64) |       - name: Build Backend (arm64) | ||||||
|         run: | |         run: | | ||||||
| @@ -56,4 +51,4 @@ jobs: | |||||||
|           draft: true |           draft: true | ||||||
|           generate_release_notes: true |           generate_release_notes: true | ||||||
|         env: |         env: | ||||||
|           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} |           GITHUB_TOKEN: ${{ secrets.GT_Token }} | ||||||
|   | |||||||
							
								
								
									
										19
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										19
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,13 +5,8 @@ permissions: | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - "*" | ||||||
|       - '!*-alpha*' |       - "!*-alpha*" | ||||||
|   workflow_dispatch: |  | ||||||
|     inputs: |  | ||||||
|       name: |  | ||||||
|         description: 'reason' |  | ||||||
|         required: false |  | ||||||
| jobs: | jobs: | ||||||
|   release: |   release: | ||||||
|     runs-on: macos-latest |     runs-on: macos-latest | ||||||
| @@ -28,17 +23,17 @@ jobs: | |||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
|           cd web |           cd web | ||||||
|           git describe --tags > VERSION |           npm install | ||||||
|           REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh |           REACT_APP_VERSION=$(git describe --tags) npm run build | ||||||
|           cd .. |           cd .. | ||||||
|       - name: Set up Go |       - name: Set up Go | ||||||
|         uses: actions/setup-go@v3 |         uses: actions/setup-go@v3 | ||||||
|         with: |         with: | ||||||
|           go-version: '>=1.18.0' |           go-version: ">=1.18.0" | ||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos |           go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
| @@ -47,4 +42,4 @@ jobs: | |||||||
|           draft: true |           draft: true | ||||||
|           generate_release_notes: true |           generate_release_notes: true | ||||||
|         env: |         env: | ||||||
|           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} |           GITHUB_TOKEN: ${{ secrets.GT_Token }} | ||||||
|   | |||||||
							
								
								
									
										19
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										19
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,13 +5,8 @@ permissions: | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - "*" | ||||||
|       - '!*-alpha*' |       - "!*-alpha*" | ||||||
|   workflow_dispatch: |  | ||||||
|     inputs: |  | ||||||
|       name: |  | ||||||
|         description: 'reason' |  | ||||||
|         required: false |  | ||||||
| jobs: | jobs: | ||||||
|   release: |   release: | ||||||
|     runs-on: windows-latest |     runs-on: windows-latest | ||||||
| @@ -30,18 +25,18 @@ jobs: | |||||||
|         env: |         env: | ||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
|           cd web/default |           cd web | ||||||
|           npm install |           npm install | ||||||
|           REACT_APP_VERSION=$(git describe --tags) npm run build |           REACT_APP_VERSION=$(git describe --tags) npm run build | ||||||
|           cd ../.. |           cd .. | ||||||
|       - name: Set up Go |       - name: Set up Go | ||||||
|         uses: actions/setup-go@v3 |         uses: actions/setup-go@v3 | ||||||
|         with: |         with: | ||||||
|           go-version: '>=1.18.0' |           go-version: ">=1.18.0" | ||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe |           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
| @@ -50,4 +45,4 @@ jobs: | |||||||
|           draft: true |           draft: true | ||||||
|           generate_release_notes: true |           generate_release_notes: true | ||||||
|         env: |         env: | ||||||
|           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} |           GITHUB_TOKEN: ${{ secrets.GT_Token }} | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -7,4 +7,5 @@ build | |||||||
| *.db-journal | *.db-journal | ||||||
| logs | logs | ||||||
| data | data | ||||||
| /web/node_modules | tmp/ | ||||||
|  | .env | ||||||
							
								
								
									
										19
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								Dockerfile
									
									
									
									
									
								
							| @@ -1,15 +1,10 @@ | |||||||
| FROM node:16 as builder | FROM node:16 as builder | ||||||
|  |  | ||||||
| WORKDIR /web | WORKDIR /build | ||||||
| COPY ./VERSION . | COPY web/package.json . | ||||||
|  | RUN npm install | ||||||
| COPY ./web . | COPY ./web . | ||||||
|  | COPY ./VERSION . | ||||||
| WORKDIR /web/default |  | ||||||
| RUN npm install |  | ||||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build |  | ||||||
|  |  | ||||||
| WORKDIR /web/berry |  | ||||||
| RUN npm install |  | ||||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||||
|  |  | ||||||
| FROM golang AS builder2 | FROM golang AS builder2 | ||||||
| @@ -22,8 +17,8 @@ WORKDIR /build | |||||||
| ADD go.mod go.sum ./ | ADD go.mod go.sum ./ | ||||||
| RUN go mod download | RUN go mod download | ||||||
| COPY . . | COPY . . | ||||||
| COPY --from=builder /web/build ./web/build | COPY --from=builder /build/build ./web/build | ||||||
| RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
| FROM alpine | FROM alpine | ||||||
|  |  | ||||||
| @@ -35,4 +30,4 @@ RUN apk update \ | |||||||
| COPY --from=builder2 /build/one-api / | COPY --from=builder2 /build/one-api / | ||||||
| EXPOSE 3000 | EXPOSE 3000 | ||||||
| WORKDIR /data | WORKDIR /data | ||||||
| ENTRYPOINT ["/one-api"] | ENTRYPOINT ["/one-api"] | ||||||
|   | |||||||
							
								
								
									
										142
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										142
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -3,35 +3,45 @@ | |||||||
| </p> | </p> | ||||||
|  |  | ||||||
| <p align="center"> | <p align="center"> | ||||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> |   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
| <div align="center"> | <div align="center"> | ||||||
|  |  | ||||||
| # One API | # One API | ||||||
|  |  | ||||||
|  | _This project is a derivative of [one-api](https://github.com/songquanpeng/one-api), where the main focus has been on modularizing the module code from the original project and modifying the frontend interface. This project also adheres to the MIT License._ | ||||||
|  |  | ||||||
|  | <p align="center"> | ||||||
|  |   <a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE"> | ||||||
|  |     <img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/MartialBE/one-api/releases/latest"> | ||||||
|  |     <img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/users/MartialBE/packages/container/package/one-api"> | ||||||
|  |     <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://goreportcard.com/report/github.com/MartialBE/one-api"> | ||||||
|  |     <img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard"> | ||||||
|  |   </a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | **Please do not mix with the original version, as the different channel ID may cause data disorder.** | ||||||
|  |  | ||||||
|  | ## Screenshots | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _The following is the original project description:_ | ||||||
|  |  | ||||||
|  | --- | ||||||
|  |  | ||||||
| _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ | _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
|  |  | ||||||
| <p align="center"> |  | ||||||
|   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> |  | ||||||
|     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> |  | ||||||
|     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> |  | ||||||
|     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> |  | ||||||
|     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> |  | ||||||
|     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> |  | ||||||
|   </a> |  | ||||||
| </p> |  | ||||||
|  |  | ||||||
| <p align="center"> | <p align="center"> | ||||||
|   <a href="#deployment">Deployment Tutorial</a> |   <a href="#deployment">Deployment Tutorial</a> | ||||||
|   · |   · | ||||||
| @@ -57,13 +67,14 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use | |||||||
| > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | ||||||
|  |  | ||||||
| ## Features | ## Features | ||||||
|  |  | ||||||
| 1. Support for multiple large models: | 1. Support for multiple large models: | ||||||
|    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) |    - [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|    + [x] [Anthropic Claude Series Models](https://anthropic.com) |    - [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||||
|    + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) |    - [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) | ||||||
|    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) |    - [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) |    - [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) |    - [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||||
| 2. Supports access to multiple channels through **load balancing**. | 2. Supports access to multiple channels through **load balancing**. | ||||||
| 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | ||||||
| 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | ||||||
| @@ -82,13 +93,15 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use | |||||||
| 15. Supports management API access through system access tokens. | 15. Supports management API access through system access tokens. | ||||||
| 16. Supports Cloudflare Turnstile user verification. | 16. Supports Cloudflare Turnstile user verification. | ||||||
| 17. Supports user management and multiple user login/registration methods: | 17. Supports user management and multiple user login/registration methods: | ||||||
|     + Email login/registration and password reset via email. |     - Email login/registration and password reset via email. | ||||||
|     + [GitHub OAuth](https://github.com/settings/applications/new). |     - [GitHub OAuth](https://github.com/settings/applications/new). | ||||||
|     + WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)). |     - WeChat Official Account authorization (requires additional deployment of [WeChat Server](https://github.com/songquanpeng/wechat-server)). | ||||||
| 18. Immediate support and encapsulation of other major model APIs as they become available. | 18. Immediate support and encapsulation of other major model APIs as they become available. | ||||||
|  |  | ||||||
| ## Deployment | ## Deployment | ||||||
|  |  | ||||||
| ### Docker Deployment | ### Docker Deployment | ||||||
|  |  | ||||||
| Deployment command: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en` | Deployment command: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en` | ||||||
|  |  | ||||||
| Update command: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | Update command: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | ||||||
| @@ -98,10 +111,11 @@ The first `3000` in `-p 3000:3000` is the port of the host, which can be modifie | |||||||
| Data will be saved in the `/home/ubuntu/data/one-api` directory on the host. Ensure that the directory exists and has write permissions, or change it to a suitable directory. | Data will be saved in the `/home/ubuntu/data/one-api` directory on the host. Ensure that the directory exists and has write permissions, or change it to a suitable directory. | ||||||
|  |  | ||||||
| Nginx reference configuration: | Nginx reference configuration: | ||||||
|  |  | ||||||
| ``` | ``` | ||||||
| server{ | server{ | ||||||
|    server_name openai.justsong.cn;  # Modify your domain name accordingly |    server_name openai.justsong.cn;  # Modify your domain name accordingly | ||||||
|     |  | ||||||
|    location / { |    location / { | ||||||
|           client_max_body_size  64m; |           client_max_body_size  64m; | ||||||
|           proxy_http_version 1.1; |           proxy_http_version 1.1; | ||||||
| @@ -115,6 +129,7 @@ server{ | |||||||
| ``` | ``` | ||||||
|  |  | ||||||
| Next, configure HTTPS with Let's Encrypt certbot: | Next, configure HTTPS with Let's Encrypt certbot: | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| # Install certbot on Ubuntu: | # Install certbot on Ubuntu: | ||||||
| sudo snap install --classic certbot | sudo snap install --classic certbot | ||||||
| @@ -129,20 +144,23 @@ sudo service nginx restart | |||||||
| The initial account username is `root` and password is `123456`. | The initial account username is `root` and password is `123456`. | ||||||
|  |  | ||||||
| ### Manual Deployment | ### Manual Deployment | ||||||
|  |  | ||||||
| 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: | 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: | ||||||
|  |  | ||||||
|    ```shell |    ```shell | ||||||
|    git clone https://github.com/songquanpeng/one-api.git |    git clone https://github.com/songquanpeng/one-api.git | ||||||
|     |  | ||||||
|    # Build the frontend |    # Build the frontend | ||||||
|    cd one-api/web/default |    cd one-api/web | ||||||
|    npm install |    npm install | ||||||
|    npm run build |    npm run build | ||||||
|     |  | ||||||
|    # Build the backend |    # Build the backend | ||||||
|    cd ../.. |    cd .. | ||||||
|    go mod download |    go mod download | ||||||
|    go build -ldflags "-s -w" -o one-api |    go build -ldflags "-s -w" -o one-api | ||||||
|    ``` |    ``` | ||||||
|  |  | ||||||
| 2. Run: | 2. Run: | ||||||
|    ```shell |    ```shell | ||||||
|    chmod u+x one-api |    chmod u+x one-api | ||||||
| @@ -153,6 +171,7 @@ The initial account username is `root` and password is `123456`. | |||||||
| For more detailed deployment tutorials, please refer to [this page](https://iamazing.cn/page/how-to-deploy-a-website). | For more detailed deployment tutorials, please refer to [this page](https://iamazing.cn/page/how-to-deploy-a-website). | ||||||
|  |  | ||||||
| ### Multi-machine Deployment | ### Multi-machine Deployment | ||||||
|  |  | ||||||
| 1. Set the same `SESSION_SECRET` for all servers. | 1. Set the same `SESSION_SECRET` for all servers. | ||||||
| 2. Set `SQL_DSN` and use MySQL instead of SQLite. All servers should connect to the same database. | 2. Set `SQL_DSN` and use MySQL instead of SQLite. All servers should connect to the same database. | ||||||
| 3. Set the `NODE_TYPE` for all non-master nodes to `slave`. | 3. Set the `NODE_TYPE` for all non-master nodes to `slave`. | ||||||
| @@ -164,11 +183,13 @@ For more detailed deployment tutorials, please refer to [this page](https://iama | |||||||
| Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | ||||||
|  |  | ||||||
| ### Deployment on Control Panels (e.g., Baota) | ### Deployment on Control Panels (e.g., Baota) | ||||||
|  |  | ||||||
| Refer to [#175](https://github.com/songquanpeng/one-api/issues/175) for detailed instructions. | Refer to [#175](https://github.com/songquanpeng/one-api/issues/175) for detailed instructions. | ||||||
|  |  | ||||||
| If you encounter a blank page after deployment, refer to [#97](https://github.com/songquanpeng/one-api/issues/97) for possible solutions. | If you encounter a blank page after deployment, refer to [#97](https://github.com/songquanpeng/one-api/issues/97) for possible solutions. | ||||||
|  |  | ||||||
| ### Deployment on Third-Party Platforms | ### Deployment on Third-Party Platforms | ||||||
|  |  | ||||||
| <details> | <details> | ||||||
| <summary><strong>Deploy on Sealos</strong></summary> | <summary><strong>Deploy on Sealos</strong></summary> | ||||||
| <div> | <div> | ||||||
| @@ -179,7 +200,6 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | |||||||
|  |  | ||||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||||
|  |  | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
| </details> | </details> | ||||||
|  |  | ||||||
| @@ -194,7 +214,7 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | |||||||
| 1. First, fork the code. | 1. First, fork the code. | ||||||
| 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. | 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. | ||||||
| 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | ||||||
| 4. Copy the connection parameters and run ```create database `one-api` ``` to create the database. | 4. Copy the connection parameters and run `` create database `one-api`  `` to create the database. | ||||||
| 5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. | 5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. | ||||||
| 6. Automatic deployment will start, but please cancel it for now. Go to the Variable tab, add a `PORT` with a value of `3000`, and then add a `SQL_DSN` with a value of `<username>:<password>@tcp(<addr>:<port>)/one-api`. Save the changes. Please note that if `SQL_DSN` is not set, data will not be persisted, and the data will be lost after redeployment. | 6. Automatic deployment will start, but please cancel it for now. Go to the Variable tab, add a `PORT` with a value of `3000`, and then add a `SQL_DSN` with a value of `<username>:<password>@tcp(<addr>:<port>)/one-api`. Save the changes. Please note that if `SQL_DSN` is not set, data will not be persisted, and the data will be lost after redeployment. | ||||||
| 7. Select Redeploy. | 7. Select Redeploy. | ||||||
| @@ -205,6 +225,7 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | |||||||
| </details> | </details> | ||||||
|  |  | ||||||
| ## Configuration | ## Configuration | ||||||
|  |  | ||||||
| The system is ready to use out of the box. | The system is ready to use out of the box. | ||||||
|  |  | ||||||
| You can configure it by setting environment variables or command line parameters. | You can configure it by setting environment variables or command line parameters. | ||||||
| @@ -212,6 +233,7 @@ You can configure it by setting environment variables or command line parameters | |||||||
| After the system starts, log in as the `root` user to further configure the system. | After the system starts, log in as the `root` user to further configure the system. | ||||||
|  |  | ||||||
| ## Usage | ## Usage | ||||||
|  |  | ||||||
| Add your API Key on the `Channels` page, and then add an access token on the `Tokens` page. | Add your API Key on the `Channels` page, and then add an access token on the `Tokens` page. | ||||||
|  |  | ||||||
| You can then use your access token to access One API. The usage is consistent with the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). | You can then use your access token to access One API. The usage is consistent with the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). | ||||||
| @@ -235,59 +257,65 @@ Note that the token needs to be created by an administrator to specify the chann | |||||||
| If the channel ID is not provided, load balancing will be used to distribute the requests to multiple channels. | If the channel ID is not provided, load balancing will be used to distribute the requests to multiple channels. | ||||||
|  |  | ||||||
| ### Environment Variables | ### Environment Variables | ||||||
|  |  | ||||||
| 1. `REDIS_CONN_STRING`: When set, Redis will be used as the storage for request rate limiting instead of memory. | 1. `REDIS_CONN_STRING`: When set, Redis will be used as the storage for request rate limiting instead of memory. | ||||||
|     + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` |    - Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
| 2. `SESSION_SECRET`: When set, a fixed session key will be used to ensure that cookies of logged-in users are still valid after the system restarts. | 2. `SESSION_SECRET`: When set, a fixed session key will be used to ensure that cookies of logged-in users are still valid after the system restarts. | ||||||
|     + Example: `SESSION_SECRET=random_string` |    - Example: `SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | ||||||
|     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |    - Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
| 4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | 4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` |    - Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | 5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||||
|     + Example: `SYNC_FREQUENCY=60` |    - Example: `SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | 6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||||
|     + Example: `NODE_TYPE=slave` |    - Example: `NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | 7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` |    - Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | 8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` |    - Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | 9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||||
|     + Example: `POLLING_INTERVAL=5` |    - Example: `POLLING_INTERVAL=5` | ||||||
|  |  | ||||||
| ### Command Line Parameters | ### Command Line Parameters | ||||||
|  |  | ||||||
| 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. | 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. | ||||||
|     + Example: `--port 3000` |    - Example: `--port 3000` | ||||||
| 2. `--log-dir <log_dir>`: Specifies the log directory. If not set, the logs will not be saved. | 2. `--log-dir <log_dir>`: Specifies the log directory. If not set, the logs will not be saved. | ||||||
|     + Example: `--log-dir ./logs` |    - Example: `--log-dir ./logs` | ||||||
| 3. `--version`: Prints the system version number and exits. | 3. `--version`: Prints the system version number and exits. | ||||||
| 4. `--help`: Displays the command usage help and parameter descriptions. | 4. `--help`: Displays the command usage help and parameter descriptions. | ||||||
|  |  | ||||||
| ## Screenshots | ## Screenshots | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ## FAQ | ## FAQ | ||||||
|  |  | ||||||
| 1. What is quota? How is it calculated? Does One API have quota calculation issues? | 1. What is quota? How is it calculated? Does One API have quota calculation issues? | ||||||
|     + Quota = Group multiplier * Model multiplier * (number of prompt tokens + number of completion tokens * completion multiplier) |    - Quota = Group multiplier _ Model multiplier _ (number of prompt tokens + number of completion tokens \* completion multiplier) | ||||||
|     + The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition. |    - The completion multiplier is fixed at 1.33 for GPT3.5 and 2 for GPT4, consistent with the official definition. | ||||||
|     + If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different. |    - If it is not a stream mode, the official API will return the total number of tokens consumed. However, please note that the consumption multipliers for prompts and completions are different. | ||||||
| 2. Why does it prompt "insufficient quota" even though my account balance is sufficient? | 2. Why does it prompt "insufficient quota" even though my account balance is sufficient? | ||||||
|     + Please check if your token quota is sufficient. It is separate from the account balance. |    - Please check if your token quota is sufficient. It is separate from the account balance. | ||||||
|     + The token quota is used to set the maximum usage and can be freely set by the user. |    - The token quota is used to set the maximum usage and can be freely set by the user. | ||||||
| 3. It says "No available channels" when trying to use a channel. What should I do? | 3. It says "No available channels" when trying to use a channel. What should I do? | ||||||
|     + Please check the user and channel group settings. |    - Please check the user and channel group settings. | ||||||
|     + Also check the channel model settings. |    - Also check the channel model settings. | ||||||
| 4. Channel testing reports an error: "invalid character '<' looking for beginning of value" | 4. Channel testing reports an error: "invalid character '<' looking for beginning of value" | ||||||
|     + This error occurs when the returned value is not valid JSON but an HTML page. |    - This error occurs when the returned value is not valid JSON but an HTML page. | ||||||
|     + Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare. |    - Most likely, the IP of your deployment site or the node of the proxy has been blocked by CloudFlare. | ||||||
| 5. ChatGPT Next Web reports an error: "Failed to fetch" | 5. ChatGPT Next Web reports an error: "Failed to fetch" | ||||||
|     + Do not set `BASE_URL` during deployment. |    - Do not set `BASE_URL` during deployment. | ||||||
|     + Double-check that your interface address and API Key are correct. |    - Double-check that your interface address and API Key are correct. | ||||||
|  |  | ||||||
| ## Related Projects | ## Related Projects | ||||||
|  |  | ||||||
| [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||||
|  |  | ||||||
| ## Note | ## Note | ||||||
|  |  | ||||||
| This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | ||||||
|  |  | ||||||
| This project is released under the MIT license. Based on this, attribution and a link to this project must be included at the bottom of the page. | This project is released under the MIT license. Based on this, attribution and a link to this project must be included at the bottom of the page. | ||||||
|   | |||||||
							
								
								
									
										140
									
								
								README.ja.md
									
									
									
									
									
								
							
							
						
						
									
										140
									
								
								README.ja.md
									
									
									
									
									
								
							| @@ -3,35 +3,45 @@ | |||||||
| </p> | </p> | ||||||
|  |  | ||||||
| <p align="center"> | <p align="center"> | ||||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> |   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
| <div align="center"> | <div align="center"> | ||||||
|  |  | ||||||
| # One API | # One API | ||||||
|  |  | ||||||
|  | _このプロジェクトは、[one-api](https://github.com/songquanpeng/one-api)をベースにしており、元のプロジェクトのモジュールコードを分離し、モジュール化し、フロントエンドのインターフェースを変更しました。このプロジェクトも MIT ライセンスに従っています。_ | ||||||
|  |  | ||||||
|  | <p align="center"> | ||||||
|  |   <a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE"> | ||||||
|  |     <img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/MartialBE/one-api/releases/latest"> | ||||||
|  |     <img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/users/MartialBE/packages/container/package/one-api"> | ||||||
|  |     <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://goreportcard.com/report/github.com/MartialBE/one-api"> | ||||||
|  |     <img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard"> | ||||||
|  |   </a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | **オリジナルバージョンと混合しないでください。チャンネル ID が異なるため、データの混乱を引き起こす可能性があります** | ||||||
|  |  | ||||||
|  | ## スクリーンショット | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _以下は元の項目の説明です:_ | ||||||
|  |  | ||||||
|  | --- | ||||||
|  |  | ||||||
| _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_ | _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_ | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
|  |  | ||||||
| <p align="center"> |  | ||||||
|   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> |  | ||||||
|     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> |  | ||||||
|     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> |  | ||||||
|     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> |  | ||||||
|     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> |  | ||||||
|     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> |  | ||||||
|   </a> |  | ||||||
| </p> |  | ||||||
|  |  | ||||||
| <p align="center"> | <p align="center"> | ||||||
|   <a href="#deployment">デプロイチュートリアル</a> |   <a href="#deployment">デプロイチュートリアル</a> | ||||||
|   · |   · | ||||||
| @@ -57,13 +67,14 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に | |||||||
| > **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。 | > **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。 | ||||||
|  |  | ||||||
| ## 特徴 | ## 特徴 | ||||||
|  |  | ||||||
| 1. 複数の大型モデルをサポート: | 1. 複数の大型モデルをサポート: | ||||||
|    + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) |    - [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) | ||||||
|    + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) |    - [x] [Anthropic Claude シリーズモデル](https://anthropic.com) | ||||||
|    + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) |    - [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) | ||||||
|    + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) |    - [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|    + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) |    - [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|    + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) |    - [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) | ||||||
| 2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。 | 2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。 | ||||||
| 3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。 | 3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。 | ||||||
| 4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。 | 4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。 | ||||||
| @@ -82,13 +93,15 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に | |||||||
| 15. システム・アクセストークンによる管理 API アクセスをサポートする。 | 15. システム・アクセストークンによる管理 API アクセスをサポートする。 | ||||||
| 16. Cloudflare Turnstile によるユーザー認証に対応。 | 16. Cloudflare Turnstile によるユーザー認証に対応。 | ||||||
| 17. ユーザー管理と複数のユーザーログイン/登録方法をサポート: | 17. ユーザー管理と複数のユーザーログイン/登録方法をサポート: | ||||||
|     + 電子メールによるログイン/登録とパスワードリセット。 |     - 電子メールによるログイン/登録とパスワードリセット。 | ||||||
|     + [GitHub OAuth](https://github.com/settings/applications/new)。 |     - [GitHub OAuth](https://github.com/settings/applications/new)。 | ||||||
|     + WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。 |     - WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。 | ||||||
| 18. 他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。 | 18. 他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。 | ||||||
|  |  | ||||||
| ## デプロイメント | ## デプロイメント | ||||||
|  |  | ||||||
| ### Docker デプロイメント | ### Docker デプロイメント | ||||||
|  |  | ||||||
| デプロイコマンド: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`。 | デプロイコマンド: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`。 | ||||||
|  |  | ||||||
| コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`。 | コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`。 | ||||||
| @@ -97,7 +110,8 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に | |||||||
|  |  | ||||||
| データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。 | データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。 | ||||||
|  |  | ||||||
| Nginxリファレンス設定: | Nginx リファレンス設定: | ||||||
|  |  | ||||||
| ``` | ``` | ||||||
| server{ | server{ | ||||||
|    server_name openai.justsong.cn;  # ドメイン名は適宜変更 |    server_name openai.justsong.cn;  # ドメイン名は適宜変更 | ||||||
| @@ -116,6 +130,7 @@ server{ | |||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 次に、Let's Encrypt certbot を使って HTTPS を設定します: | 次に、Let's Encrypt certbot を使って HTTPS を設定します: | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| # Ubuntu に certbot をインストール: | # Ubuntu に certbot をインストール: | ||||||
| sudo snap install --classic certbot | sudo snap install --classic certbot | ||||||
| @@ -130,20 +145,23 @@ sudo service nginx restart | |||||||
| 初期アカウントのユーザー名は `root` で、パスワードは `123456` です。 | 初期アカウントのユーザー名は `root` で、パスワードは `123456` です。 | ||||||
|  |  | ||||||
| ### マニュアルデプロイ | ### マニュアルデプロイ | ||||||
|  |  | ||||||
| 1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする: | 1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする: | ||||||
|  |  | ||||||
|    ```shell |    ```shell | ||||||
|    git clone https://github.com/songquanpeng/one-api.git |    git clone https://github.com/songquanpeng/one-api.git | ||||||
|  |  | ||||||
|    # フロントエンドのビルド |    # フロントエンドのビルド | ||||||
|    cd one-api/web/default |    cd one-api/web | ||||||
|    npm install |    npm install | ||||||
|    npm run build |    npm run build | ||||||
|  |  | ||||||
|    # バックエンドのビルド |    # バックエンドのビルド | ||||||
|    cd ../.. |    cd .. | ||||||
|    go mod download |    go mod download | ||||||
|    go build -ldflags "-s -w" -o one-api |    go build -ldflags "-s -w" -o one-api | ||||||
|    ``` |    ``` | ||||||
|  |  | ||||||
| 2. 実行: | 2. 実行: | ||||||
|    ```shell |    ```shell | ||||||
|    chmod u+x one-api |    chmod u+x one-api | ||||||
| @@ -154,6 +172,7 @@ sudo service nginx restart | |||||||
| より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。 | より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。 | ||||||
|  |  | ||||||
| ### マルチマシンデプロイ | ### マルチマシンデプロイ | ||||||
|  |  | ||||||
| 1. すべてのサーバに同じ `SESSION_SECRET` を設定する。 | 1. すべてのサーバに同じ `SESSION_SECRET` を設定する。 | ||||||
| 2. `SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。 | 2. `SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。 | ||||||
| 3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。 | 3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。 | ||||||
| @@ -165,11 +184,13 @@ sudo service nginx restart | |||||||
| Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | Please refer to the [environment variables](#environment-variables) section for details on using environment variables. | ||||||
|  |  | ||||||
| ### コントロールパネル(例: Baota)への展開 | ### コントロールパネル(例: Baota)への展開 | ||||||
|  |  | ||||||
| 詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。 | 詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。 | ||||||
|  |  | ||||||
| 配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。 | 配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。 | ||||||
|  |  | ||||||
| ### サードパーティプラットフォームへのデプロイ | ### サードパーティプラットフォームへのデプロイ | ||||||
|  |  | ||||||
| <details> | <details> | ||||||
| <summary><strong>Sealos へのデプロイ</strong></summary> | <summary><strong>Sealos へのデプロイ</strong></summary> | ||||||
| <div> | <div> | ||||||
| @@ -180,7 +201,6 @@ Please refer to the [environment variables](#environment-variables) section for | |||||||
|  |  | ||||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||||
|  |  | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
| </details> | </details> | ||||||
|  |  | ||||||
| @@ -194,8 +214,8 @@ Please refer to the [environment variables](#environment-variables) section for | |||||||
|  |  | ||||||
| 1. まず、コードをフォークする。 | 1. まず、コードをフォークする。 | ||||||
| 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 | 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 | ||||||
| 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 | 3. 新しいプロジェクトを作成します。Service -> Add Service で Marketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 | ||||||
| 4. 接続パラメータをコピーし、```create database `one-api` ``` を実行してデータベースを作成する。 | 4. 接続パラメータをコピーし、`` create database `one-api`  `` を実行してデータベースを作成する。 | ||||||
| 5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。 | 5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。 | ||||||
| 6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `<username>:<password>@tcp(<addr>:<port>)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。 | 6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `<username>:<password>@tcp(<addr>:<port>)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。 | ||||||
| 7. 再デプロイを選択します。 | 7. 再デプロイを選択します。 | ||||||
| @@ -206,6 +226,7 @@ Please refer to the [environment variables](#environment-variables) section for | |||||||
| </details> | </details> | ||||||
|  |  | ||||||
| ## コンフィグ | ## コンフィグ | ||||||
|  |  | ||||||
| システムは箱から出してすぐに使えます。 | システムは箱から出してすぐに使えます。 | ||||||
|  |  | ||||||
| 環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。 | 環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。 | ||||||
| @@ -213,6 +234,7 @@ Please refer to the [environment variables](#environment-variables) section for | |||||||
| システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。 | システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。 | ||||||
|  |  | ||||||
| ## 使用方法 | ## 使用方法 | ||||||
|  |  | ||||||
| `Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。 | `Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。 | ||||||
|  |  | ||||||
| アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。 | アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。 | ||||||
| @@ -236,59 +258,65 @@ graph LR | |||||||
| もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。 | もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。 | ||||||
|  |  | ||||||
| ### 環境変数 | ### 環境変数 | ||||||
|  |  | ||||||
| 1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。 | 1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。 | ||||||
|     + 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` |    - 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
| 2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。 | 2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。 | ||||||
|     + 例: `SESSION_SECRET=random_string` |    - 例: `SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | ||||||
|     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |    - 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
| 4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | 4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | ||||||
|     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` |    - 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | 5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | ||||||
|     + 例: `SYNC_FREQUENCY=60` |    - 例: `SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | 6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | ||||||
|     + 例: `NODE_TYPE=slave` |    - 例: `NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | 7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | ||||||
|     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` |    - 例: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | 8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | ||||||
|     + 例: `CHANNEL_TEST_FREQUENCY=1440` |    - 例: `CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | 9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | ||||||
|     + 例: `POLLING_INTERVAL=5` |    - 例: `POLLING_INTERVAL=5` | ||||||
|  |  | ||||||
| ### コマンドラインパラメータ | ### コマンドラインパラメータ | ||||||
|  |  | ||||||
| 1. `--port <port_number>`: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。 | 1. `--port <port_number>`: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。 | ||||||
|     + 例: `--port 3000` |    - 例: `--port 3000` | ||||||
| 2. `--log-dir <log_dir>`: ログディレクトリを指定。設定しない場合、ログは保存されません。 | 2. `--log-dir <log_dir>`: ログディレクトリを指定。設定しない場合、ログは保存されません。 | ||||||
|     + 例: `--log-dir ./logs` |    - 例: `--log-dir ./logs` | ||||||
| 3. `--version`: システムのバージョン番号を表示して終了する。 | 3. `--version`: システムのバージョン番号を表示して終了する。 | ||||||
| 4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。 | 4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。 | ||||||
|  |  | ||||||
| ## スクリーンショット | ## スクリーンショット | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ## FAQ | ## FAQ | ||||||
|  |  | ||||||
| 1. ノルマとは何か?どのように計算されますか?One API にはノルマ計算の問題はありますか? | 1. ノルマとは何か?どのように計算されますか?One API にはノルマ計算の問題はありますか? | ||||||
|     + ノルマ = グループ倍率 * モデル倍率 * (プロンプトトークンの数 + 完了トークンの数 * 完了倍率) |    - ノルマ = グループ倍率 _ モデル倍率 _ (プロンプトトークンの数 + 完了トークンの数 \* 完了倍率) | ||||||
|     + 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。 |    - 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。 | ||||||
|     + ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。 |    - ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。 | ||||||
| 2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか? | 2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか? | ||||||
|     + トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。 |    - トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。 | ||||||
|     + トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。 |    - トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。 | ||||||
| 3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか? | 3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか? | ||||||
|     + ユーザーとチャンネルグループの設定を確認してください。 |    - ユーザーとチャンネルグループの設定を確認してください。 | ||||||
|     + チャンネルモデルの設定も確認してください。 |    - チャンネルモデルの設定も確認してください。 | ||||||
| 4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value" | 4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value" | ||||||
|     + このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。 |    - このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。 | ||||||
|     + ほとんどの場合、デプロイサイトのIPかプロキシのノードが CloudFlare によってブロックされています。 |    - ほとんどの場合、デプロイサイトの IP かプロキシのノードが CloudFlare によってブロックされています。 | ||||||
| 5. ChatGPT Next Web でエラーが発生しました: "Failed to fetch" | 5. ChatGPT Next Web でエラーが発生しました: "Failed to fetch" | ||||||
|     + デプロイ時に `BASE_URL` を設定しないでください。 |    - デプロイ時に `BASE_URL` を設定しないでください。 | ||||||
|     + インターフェイスアドレスと API Key が正しいか再確認してください。 |    - インターフェイスアドレスと API Key が正しいか再確認してください。 | ||||||
|  |  | ||||||
| ## 関連プロジェクト | ## 関連プロジェクト | ||||||
|  |  | ||||||
| [FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム | [FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム | ||||||
|  |  | ||||||
| ## 注 | ## 注 | ||||||
|  |  | ||||||
| 本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。 | 本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。 | ||||||
|  |  | ||||||
| このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。 | このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。 | ||||||
|   | |||||||
							
								
								
									
										231
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										231
									
								
								README.md
									
									
									
									
									
								
							| @@ -2,37 +2,46 @@ | |||||||
|    <strong>中文</strong> | <a href="./README.en.md">English</a> | <a href="./README.ja.md">日本語</a> |    <strong>中文</strong> | <a href="./README.en.md">English</a> | <a href="./README.ja.md">日本語</a> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
|  |  | ||||||
| <p align="center"> | <p align="center"> | ||||||
|   <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> |   <a href="https://github.com/MartialBE/one-api"><img src="https://raw.githubusercontent.com/MartialBE/one-api/main/web/src/assets/images/logo.svg" width="150" height="150" alt="one-api logo"></a> | ||||||
| </p> | </p> | ||||||
|  |  | ||||||
| <div align="center"> | <div align="center"> | ||||||
|  |  | ||||||
| # One API | # One API | ||||||
|  |  | ||||||
|  | _本项目是基于[one-api](https://github.com/songquanpeng/one-api)二次开发而来的,主要将原项目中的模块代码分离,模块化,并修改了前端界面。本项目同样遵循 MIT 协议。_ | ||||||
|  |  | ||||||
|  | <p align="center"> | ||||||
|  |   <a href="https://raw.githubusercontent.com/MartialBE/one-api/main/LICENSE"> | ||||||
|  |     <img src="https://img.shields.io/github/license/MartialBE/one-api?color=brightgreen" alt="license"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/MartialBE/one-api/releases/latest"> | ||||||
|  |     <img src="https://img.shields.io/github/v/release/MartialBE/one-api?color=brightgreen&include_prereleases" alt="release"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://github.com/users/MartialBE/packages/container/package/one-api"> | ||||||
|  |     <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker"> | ||||||
|  |   </a> | ||||||
|  |   <a href="https://goreportcard.com/report/github.com/MartialBE/one-api"> | ||||||
|  |     <img src="https://goreportcard.com/badge/github.com/MartialBE/one-api" alt="GoReportCard"> | ||||||
|  |   </a> | ||||||
|  | </p> | ||||||
|  |  | ||||||
|  | **请不要和原版混用,因为 channel id 不同的原因,会导致数据错乱** | ||||||
|  |  | ||||||
|  | # 截图展示 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _以下为原项目说明:_ | ||||||
|  |  | ||||||
|  | --- | ||||||
|  |  | ||||||
| _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ | _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
|  |  | ||||||
| <p align="center"> |  | ||||||
|   <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE"> |  | ||||||
|     <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> |  | ||||||
|     <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://hub.docker.com/repository/docker/justsong/one-api"> |  | ||||||
|     <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://github.com/songquanpeng/one-api/releases/latest"> |  | ||||||
|     <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release"> |  | ||||||
|   </a> |  | ||||||
|   <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api"> |  | ||||||
|     <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard"> |  | ||||||
|   </a> |  | ||||||
| </p> |  | ||||||
|  |  | ||||||
| <p align="center"> | <p align="center"> | ||||||
|   <a href="https://github.com/songquanpeng/one-api#部署">部署教程</a> |   <a href="https://github.com/songquanpeng/one-api#部署">部署教程</a> | ||||||
|   · |   · | ||||||
| @@ -53,7 +62,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|  |  | ||||||
| > [!NOTE] | > [!NOTE] | ||||||
| > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||||
| >  | > | ||||||
| > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | ||||||
|  |  | ||||||
| > [!WARNING] | > [!WARNING] | ||||||
| @@ -63,21 +72,17 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | ||||||
|  |  | ||||||
| ## 功能 | ## 功能 | ||||||
|  |  | ||||||
| 1. 支持多种大模型: | 1. 支持多种大模型: | ||||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) |    - [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) |    - [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||||
|    + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) |    - [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) | ||||||
|    + [x] [Mistral 系列模型](https://mistral.ai/) |    - [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) |    - [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) |    - [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) |    - [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) |    - [x] [360 智脑](https://ai.360.cn) | ||||||
|    + [x] [360 智脑](https://ai.360.cn) |    - [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||||
|    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) |  | ||||||
|    + [x] [Moonshot AI](https://platform.moonshot.cn/) |  | ||||||
|    + [x] [百川大模型](https://platform.baichuan-ai.com) |  | ||||||
|    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) |  | ||||||
|    + [x] [MINIMAX](https://api.minimax.chat/) |  | ||||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| @@ -101,13 +106,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | ||||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | 21. 支持 Cloudflare Turnstile 用户校验。 | ||||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 |     - 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     - [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     - 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 |  | ||||||
|  |  | ||||||
| ## 部署 | ## 部署 | ||||||
|  |  | ||||||
| ### 基于 Docker 进行部署 | ### 基于 Docker 进行部署 | ||||||
|  |  | ||||||
| ```shell | ```shell | ||||||
| # 使用 SQLite 的部署命令: | # 使用 SQLite 的部署命令: | ||||||
| docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||||
| @@ -129,10 +135,11 @@ docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234 | |||||||
| 更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | 更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` | ||||||
|  |  | ||||||
| Nginx 的参考配置: | Nginx 的参考配置: | ||||||
|  |  | ||||||
| ``` | ``` | ||||||
| server{ | server{ | ||||||
|    server_name openai.justsong.cn;  # 请根据实际情况修改你的域名 |    server_name openai.justsong.cn;  # 请根据实际情况修改你的域名 | ||||||
|     |  | ||||||
|    location / { |    location / { | ||||||
|           client_max_body_size  64m; |           client_max_body_size  64m; | ||||||
|           proxy_http_version 1.1; |           proxy_http_version 1.1; | ||||||
| @@ -147,6 +154,7 @@ server{ | |||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 之后使用 Let's Encrypt 的 certbot 配置 HTTPS: | 之后使用 Let's Encrypt 的 certbot 配置 HTTPS: | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| # Ubuntu 安装 certbot: | # Ubuntu 安装 certbot: | ||||||
| sudo snap install --classic certbot | sudo snap install --classic certbot | ||||||
| @@ -160,7 +168,6 @@ sudo service nginx restart | |||||||
|  |  | ||||||
| 初始账号用户名为 `root`,密码为 `123456`。 | 初始账号用户名为 `root`,密码为 `123456`。 | ||||||
|  |  | ||||||
|  |  | ||||||
| ### 基于 Docker Compose 进行部署 | ### 基于 Docker Compose 进行部署 | ||||||
|  |  | ||||||
| > 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 | > 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 | ||||||
| @@ -174,20 +181,23 @@ docker-compose ps | |||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### 手动部署 | ### 手动部署 | ||||||
|  |  | ||||||
| 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: | 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: | ||||||
|  |  | ||||||
|    ```shell |    ```shell | ||||||
|    git clone https://github.com/songquanpeng/one-api.git |    git clone https://github.com/songquanpeng/one-api.git | ||||||
|     |  | ||||||
|    # 构建前端 |    # 构建前端 | ||||||
|    cd one-api/web/default |    cd one-api/web | ||||||
|    npm install |    npm install | ||||||
|    npm run build |    npm run build | ||||||
|     |  | ||||||
|    # 构建后端 |    # 构建后端 | ||||||
|    cd ../.. |    cd .. | ||||||
|    go mod download |    go mod download | ||||||
|    go build -ldflags "-s -w" -o one-api |    go build -ldflags "-s -w" -o one-api | ||||||
|    ```` |    ``` | ||||||
|  |  | ||||||
| 2. 运行: | 2. 运行: | ||||||
|    ```shell |    ```shell | ||||||
|    chmod u+x one-api |    chmod u+x one-api | ||||||
| @@ -198,6 +208,7 @@ docker-compose ps | |||||||
| 更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。 | 更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。 | ||||||
|  |  | ||||||
| ### 多机部署 | ### 多机部署 | ||||||
|  |  | ||||||
| 1. 所有服务器 `SESSION_SECRET` 设置一样的值。 | 1. 所有服务器 `SESSION_SECRET` 设置一样的值。 | ||||||
| 2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,所有服务器连接同一个数据库。 | 2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,所有服务器连接同一个数据库。 | ||||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | ||||||
| @@ -215,9 +226,11 @@ docker-compose ps | |||||||
| 如果部署后访问出现空白页面,详见 [#97](https://github.com/songquanpeng/one-api/issues/97)。 | 如果部署后访问出现空白页面,详见 [#97](https://github.com/songquanpeng/one-api/issues/97)。 | ||||||
|  |  | ||||||
| ### 部署第三方服务配合 One API 使用 | ### 部署第三方服务配合 One API 使用 | ||||||
|  |  | ||||||
| > 欢迎 PR 添加更多示例。 | > 欢迎 PR 添加更多示例。 | ||||||
|  |  | ||||||
| #### ChatGPT Next Web | #### ChatGPT Next Web | ||||||
|  |  | ||||||
| 项目主页:https://github.com/Yidadaa/ChatGPT-Next-Web | 项目主页:https://github.com/Yidadaa/ChatGPT-Next-Web | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| @@ -227,6 +240,7 @@ docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web | |||||||
| 注意修改端口号,之后在页面上设置接口地址(例如:https://openai.justsong.cn/ )和 API Key 即可。 | 注意修改端口号,之后在页面上设置接口地址(例如:https://openai.justsong.cn/ )和 API Key 即可。 | ||||||
|  |  | ||||||
| #### ChatGPT Web | #### ChatGPT Web | ||||||
|  |  | ||||||
| 项目主页:https://github.com/Chanzhaoyu/chatgpt-web | 项目主页:https://github.com/Chanzhaoyu/chatgpt-web | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| @@ -235,14 +249,16 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
|  |  | ||||||
| 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | ||||||
|  |  | ||||||
| #### QChatGPT - QQ机器人 | #### QChatGPT - QQ 机器人 | ||||||
|  |  | ||||||
| 项目主页:https://github.com/RockChinQ/QChatGPT | 项目主页:https://github.com/RockChinQ/QChatGPT | ||||||
|  |  | ||||||
| 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的 key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||||
|  |  | ||||||
| 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | ||||||
|  |  | ||||||
| ### 部署到第三方平台 | ### 部署到第三方平台 | ||||||
|  |  | ||||||
| <details> | <details> | ||||||
| <summary><strong>部署到 Sealos </strong></summary> | <summary><strong>部署到 Sealos </strong></summary> | ||||||
| <div> | <div> | ||||||
| @@ -267,7 +283,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
| 1. 首先 fork 一份代码。 | 1. 首先 fork 一份代码。 | ||||||
| 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | ||||||
| 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 | 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 | ||||||
| 4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。 | 4. 复制链接参数,运行 `` create database `one-api`  `` 创建数据库。 | ||||||
| 5. 然后在 Service -> Add Service,选择 Git(第一次使用需要先授权),选择你 fork 的仓库。 | 5. 然后在 Service -> Add Service,选择 Git(第一次使用需要先授权),选择你 fork 的仓库。 | ||||||
| 6. Deploy 会自动开始,先取消。进入下方 Variable,添加一个 `PORT`,值为 `3000`,再添加一个 `SQL_DSN`,值为 `<username>:<password>@tcp(<addr>:<port>)/one-api` ,然后保存。 注意如果不填写 `SQL_DSN`,数据将无法持久化,重新部署后数据会丢失。 | 6. Deploy 会自动开始,先取消。进入下方 Variable,添加一个 `PORT`,值为 `3000`,再添加一个 `SQL_DSN`,值为 `<username>:<password>@tcp(<addr>:<port>)/one-api` ,然后保存。 注意如果不填写 `SQL_DSN`,数据将无法持久化,重新部署后数据会丢失。 | ||||||
| 7. 选择 Redeploy。 | 7. 选择 Redeploy。 | ||||||
| @@ -289,6 +305,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo | |||||||
| </details> | </details> | ||||||
|  |  | ||||||
| ## 配置 | ## 配置 | ||||||
|  |  | ||||||
| 系统本身开箱即用。 | 系统本身开箱即用。 | ||||||
|  |  | ||||||
| 你可以通过设置环境变量或者命令行参数进行配置。 | 你可以通过设置环境变量或者命令行参数进行配置。 | ||||||
| @@ -298,6 +315,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo | |||||||
| **Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。 | **Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。 | ||||||
|  |  | ||||||
| ## 使用方法 | ## 使用方法 | ||||||
|  |  | ||||||
| 在`渠道`页面中添加你的 API Key,之后在`令牌`页面中新增访问令牌。 | 在`渠道`页面中添加你的 API Key,之后在`令牌`页面中新增访问令牌。 | ||||||
|  |  | ||||||
| 之后就可以使用你的令牌访问 One API 了,使用方式与 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 一致。 | 之后就可以使用你的令牌访问 One API 了,使用方式与 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 一致。 | ||||||
| @@ -307,9 +325,10 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo | |||||||
| 注意,具体的 API Base 的格式取决于你所使用的客户端。 | 注意,具体的 API Base 的格式取决于你所使用的客户端。 | ||||||
|  |  | ||||||
| 例如对于 OpenAI 的官方库: | 例如对于 OpenAI 的官方库: | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| OPENAI_API_KEY="sk-xxxxxx" | OPENAI_API_KEY="sk-xxxxxx" | ||||||
| OPENAI_API_BASE="https://<HOST>:<PORT>/v1"  | OPENAI_API_BASE="https://<HOST>:<PORT>/v1" | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ```mermaid | ```mermaid | ||||||
| @@ -328,104 +347,106 @@ graph LR | |||||||
| 不加的话将会使用负载均衡的方式使用多个渠道。 | 不加的话将会使用负载均衡的方式使用多个渠道。 | ||||||
|  |  | ||||||
| ### 环境变量 | ### 环境变量 | ||||||
|  |  | ||||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` |    - 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 |    - 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||||
|    + 例子:`SESSION_SECRET=random_string` |    - 例子:`SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||||
|    + 例子: |    - 例子: | ||||||
|      + MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |      - MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
|      + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) |      - PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) | ||||||
|    + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 |    - 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 | ||||||
|    + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 |    - 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 | ||||||
|    + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 |    - 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 | ||||||
|    + 请根据你的数据库配置修改下列参数(或者保持默认值): |    - 请根据你的数据库配置修改下列参数(或者保持默认值): | ||||||
|      + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 |      - `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 | ||||||
|      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 |      - `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 | ||||||
|        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 |        - 如果报错 `Error 1040: Too many connections`,请适当减小该值。 | ||||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 |      - `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` |    - 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|    + 例子:`MEMORY_CACHE_ENABLED=true` |    - 例子:`MEMORY_CACHE_ENABLED=true` | ||||||
| 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||||
|    + 例子:`SYNC_FREQUENCY=60` |    - 例子:`SYNC_FREQUENCY=60` | ||||||
| 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||||
|    + 例子:`NODE_TYPE=slave` |    - 例子:`NODE_TYPE=slave` | ||||||
| 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` |    - 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` |    - 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||||
|     + 例子:`POLLING_INTERVAL=5` |     - 例子:`POLLING_INTERVAL=5` | ||||||
| 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` |     - 例子:`BATCH_UPDATE_ENABLED=true` | ||||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 |     - 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||||
| 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` |     - 例子:`BATCH_UPDATE_INTERVAL=5` | ||||||
| 13. 请求频率限制: | 13. 请求频率限制: | ||||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 |     - `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 |     - `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||||
| 14. 编码器缓存设置: | 14. 编码器缓存设置: | ||||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 |     - `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 |     - `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||||
| 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 |  | ||||||
| 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 |  | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
|  |  | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
|    + 例子:`--port 3000` |    - 例子:`--port 3000` | ||||||
| 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 | 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 | ||||||
|    + 例子:`--log-dir ./logs` |    - 例子:`--log-dir ./logs` | ||||||
| 3. `--version`: 打印系统版本号并退出。 | 3. `--version`: 打印系统版本号并退出。 | ||||||
| 4. `--help`: 查看命令的使用帮助和参数说明。 | 4. `--help`: 查看命令的使用帮助和参数说明。 | ||||||
|  |  | ||||||
| ## 演示 | ## 演示 | ||||||
|  |  | ||||||
| ### 在线演示 | ### 在线演示 | ||||||
|  |  | ||||||
| 注意,该演示站不提供对外服务: | 注意,该演示站不提供对外服务: | ||||||
| https://openai.justsong.cn | https://openai.justsong.cn | ||||||
|  |  | ||||||
| ### 截图展示 | ### 截图展示 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ## 常见问题 | ## 常见问题 | ||||||
|  |  | ||||||
| 1. 额度是什么?怎么计算的?One API 的额度计算有问题? | 1. 额度是什么?怎么计算的?One API 的额度计算有问题? | ||||||
|    + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) |    - 额度 = 分组倍率 _ 模型倍率 _ (提示 token 数 + 补全 token 数 \* 补全倍率) | ||||||
|    + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 |    - 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 | ||||||
|    + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 |    - 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 | ||||||
|    + 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 |    - 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 | ||||||
| 2. 账户额度足够为什么提示额度不足? | 2. 账户额度足够为什么提示额度不足? | ||||||
|    + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 |    - 请检查你的令牌额度是否足够,这个和账户额度是分开的。 | ||||||
|    + 令牌额度仅供用户设置最大使用量,用户可自由设置。 |    - 令牌额度仅供用户设置最大使用量,用户可自由设置。 | ||||||
| 3. 提示无可用渠道? | 3. 提示无可用渠道? | ||||||
|    + 请检查的用户分组和渠道分组设置。 |    - 请检查的用户分组和渠道分组设置。 | ||||||
|    + 以及渠道的模型设置。 |    - 以及渠道的模型设置。 | ||||||
| 4. 渠道测试报错:`invalid character '<' looking for beginning of value` | 4. 渠道测试报错:`invalid character '<' looking for beginning of value` | ||||||
|    + 这是因为返回值不是合法的 JSON,而是一个 HTML 页面。 |    - 这是因为返回值不是合法的 JSON,而是一个 HTML 页面。 | ||||||
|    + 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。 |    - 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。 | ||||||
| 5. ChatGPT Next Web 报错:`Failed to fetch` | 5. ChatGPT Next Web 报错:`Failed to fetch` | ||||||
|    + 部署的时候不要设置 `BASE_URL`。 |    - 部署的时候不要设置 `BASE_URL`。 | ||||||
|    + 检查你的接口地址和 API Key 有没有填对。 |    - 检查你的接口地址和 API Key 有没有填对。 | ||||||
|    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 |    - 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||||
|    + 上游通道 429 了。 |    - 上游通道 429 了。 | ||||||
| 7. 升级之后我的数据会丢失吗? | 7. 升级之后我的数据会丢失吗? | ||||||
|    + 如果使用 MySQL,不会。 |    - 如果使用 MySQL,不会。 | ||||||
|    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 |    - 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||||
| 8. 升级之前数据库需要做变更吗? | 8. 升级之前数据库需要做变更吗? | ||||||
|    + 一般情况下不需要,系统将在初始化的时候自动调整。 |    - 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||||
|    + 如果需要的话,我会在更新日志中说明,并给出脚本。 |    - 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||||
| 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? |  | ||||||
|    + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 |  | ||||||
|    + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 |  | ||||||
|  |  | ||||||
| ## 相关项目 | ## 相关项目 | ||||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 |  | ||||||
| * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用 | - [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||||
|  | - [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 | ||||||
|  |  | ||||||
| ## 注意 | ## 注意 | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										299
									
								
								common/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										299
									
								
								common/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,299 @@ | |||||||
|  | package common | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"one-api/types" | ||||||
|  | 	"strconv" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"golang.org/x/net/proxy" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var clientPool = &sync.Pool{ | ||||||
|  | 	New: func() interface{} { | ||||||
|  | 		return &http.Client{} | ||||||
|  | 	}, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetHttpClient(proxyAddr string) *http.Client { | ||||||
|  | 	client := clientPool.Get().(*http.Client) | ||||||
|  |  | ||||||
|  | 	if RelayTimeout > 0 { | ||||||
|  | 		client.Timeout = time.Duration(RelayTimeout) * time.Second | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if proxyAddr != "" { | ||||||
|  | 		proxyURL, err := url.Parse(proxyAddr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			SysError("Error parsing proxy address: " + err.Error()) | ||||||
|  | 			return client | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		switch proxyURL.Scheme { | ||||||
|  | 		case "http", "https": | ||||||
|  | 			client.Transport = &http.Transport{ | ||||||
|  | 				Proxy: http.ProxyURL(proxyURL), | ||||||
|  | 			} | ||||||
|  | 		case "socks5": | ||||||
|  | 			dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct) | ||||||
|  | 			if err != nil { | ||||||
|  | 				SysError("Error creating SOCKS5 dialer: " + err.Error()) | ||||||
|  | 				return client | ||||||
|  | 			} | ||||||
|  | 			client.Transport = &http.Transport{ | ||||||
|  | 				Dial: dialer.Dial, | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			SysError("Unsupported proxy scheme: " + proxyURL.Scheme) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return client | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func PutHttpClient(c *http.Client) { | ||||||
|  | 	clientPool.Put(c) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Client struct { | ||||||
|  | 	requestBuilder    RequestBuilder | ||||||
|  | 	CreateFormBuilder func(io.Writer) FormBuilder | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewClient() *Client { | ||||||
|  | 	return &Client{ | ||||||
|  | 		requestBuilder: NewRequestBuilder(), | ||||||
|  | 		CreateFormBuilder: func(body io.Writer) FormBuilder { | ||||||
|  | 			return NewFormBuilder(body) | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type requestOptions struct { | ||||||
|  | 	body   any | ||||||
|  | 	header http.Header | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type requestOption func(*requestOptions) | ||||||
|  |  | ||||||
|  | type Stringer interface { | ||||||
|  | 	GetString() *string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WithBody(body any) requestOption { | ||||||
|  | 	return func(args *requestOptions) { | ||||||
|  | 		args.body = body | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WithHeader(header map[string]string) requestOption { | ||||||
|  | 	return func(args *requestOptions) { | ||||||
|  | 		for k, v := range header { | ||||||
|  | 			args.header.Set(k, v) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WithContentType(contentType string) requestOption { | ||||||
|  | 	return func(args *requestOptions) { | ||||||
|  | 		args.header.Set("Content-Type", contentType) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type RequestError struct { | ||||||
|  | 	HTTPStatusCode int | ||||||
|  | 	Err            error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) { | ||||||
|  | 	// Default Options | ||||||
|  | 	args := &requestOptions{ | ||||||
|  | 		body:   nil, | ||||||
|  | 		header: make(http.Header), | ||||||
|  | 	} | ||||||
|  | 	for _, setter := range setters { | ||||||
|  | 		setter(args) | ||||||
|  | 	} | ||||||
|  | 	req, err := c.requestBuilder.Build(method, url, args.body, args.header) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return req, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	// 发送请求 | ||||||
|  | 	client := GetHttpClient(proxyAddr) | ||||||
|  | 	resp, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	PutHttpClient(client) | ||||||
|  |  | ||||||
|  | 	if !outputResp { | ||||||
|  | 		defer resp.Body.Close() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 处理响应 | ||||||
|  | 	if IsFailureStatusCode(resp) { | ||||||
|  | 		return nil, HandleErrorResp(resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析响应 | ||||||
|  | 	if outputResp { | ||||||
|  | 		var buf bytes.Buffer | ||||||
|  | 		tee := io.TeeReader(resp.Body, &buf) | ||||||
|  | 		err = DecodeResponse(tee, response) | ||||||
|  |  | ||||||
|  | 		// 将响应体重新写入 resp.Body | ||||||
|  | 		resp.Body = io.NopCloser(&buf) | ||||||
|  | 	} else { | ||||||
|  | 		err = DecodeResponse(resp.Body, response) | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if outputResp { | ||||||
|  | 		return resp, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeneralErrorResponse struct { | ||||||
|  | 	Error    types.OpenAIError `json:"error"` | ||||||
|  | 	Message  string            `json:"message"` | ||||||
|  | 	Msg      string            `json:"msg"` | ||||||
|  | 	Err      string            `json:"err"` | ||||||
|  | 	ErrorMsg string            `json:"error_msg"` | ||||||
|  | 	Header   struct { | ||||||
|  | 		Message string `json:"message"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Response struct { | ||||||
|  | 		Error struct { | ||||||
|  | 			Message string `json:"message"` | ||||||
|  | 		} `json:"error"` | ||||||
|  | 	} `json:"response"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (e GeneralErrorResponse) ToMessage() string { | ||||||
|  | 	if e.Error.Message != "" { | ||||||
|  | 		return e.Error.Message | ||||||
|  | 	} | ||||||
|  | 	if e.Message != "" { | ||||||
|  | 		return e.Message | ||||||
|  | 	} | ||||||
|  | 	if e.Msg != "" { | ||||||
|  | 		return e.Msg | ||||||
|  | 	} | ||||||
|  | 	if e.Err != "" { | ||||||
|  | 		return e.Err | ||||||
|  | 	} | ||||||
|  | 	if e.ErrorMsg != "" { | ||||||
|  | 		return e.ErrorMsg | ||||||
|  | 	} | ||||||
|  | 	if e.Header.Message != "" { | ||||||
|  | 		return e.Header.Message | ||||||
|  | 	} | ||||||
|  | 	if e.Response.Error.Message != "" { | ||||||
|  | 		return e.Response.Error.Message | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 处理错误响应 | ||||||
|  | func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ | ||||||
|  | 		StatusCode: resp.StatusCode, | ||||||
|  | 		OpenAIError: types.OpenAIError{ | ||||||
|  | 			Message: "", | ||||||
|  | 			Type:    "upstream_error", | ||||||
|  | 			Code:    "bad_response_status_code", | ||||||
|  | 			Param:   strconv.Itoa(resp.StatusCode), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	// var errorResponse types.OpenAIErrorResponse | ||||||
|  | 	var errorResponse GeneralErrorResponse | ||||||
|  | 	err = json.Unmarshal(responseBody, &errorResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if errorResponse.Error.Message != "" { | ||||||
|  | 		// OpenAI format error, so we override the default one | ||||||
|  | 		openAIErrorWithStatusCode.OpenAIError = errorResponse.Error | ||||||
|  | 	} else { | ||||||
|  | 		openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage() | ||||||
|  | 	} | ||||||
|  | 	if openAIErrorWithStatusCode.OpenAIError.Message == "" { | ||||||
|  | 		openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) { | ||||||
|  | 	client := GetHttpClient(proxyAddr) | ||||||
|  | 	resp, err := client.Do(req) | ||||||
|  | 	PutHttpClient(client) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp.Body, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func IsFailureStatusCode(resp *http.Response) bool { | ||||||
|  | 	return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DecodeResponse(body io.Reader, v any) error { | ||||||
|  | 	if v == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if result, ok := v.(*string); ok { | ||||||
|  | 		return DecodeString(body, result) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if stringer, ok := v.(Stringer); ok { | ||||||
|  | 		return DecodeString(body, stringer.GetString()) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return json.NewDecoder(body).Decode(v) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DecodeString(body io.Reader, output *string) error { | ||||||
|  | 	b, err := io.ReadAll(body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	*output = string(b) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SetEventStreamHeaders(c *gin.Context) { | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||||
|  | 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||||
|  | 	c.Writer.Header().Set("Connection", "keep-alive") | ||||||
|  | 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||||
|  | 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||||
|  | } | ||||||
| @@ -1,127 +0,0 @@ | |||||||
| package config |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"os" |  | ||||||
| 	"strconv" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/google/uuid" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var SystemName = "One API" |  | ||||||
| var ServerAddress = "http://localhost:3000" |  | ||||||
| var Footer = "" |  | ||||||
| var Logo = "" |  | ||||||
| var TopUpLink = "" |  | ||||||
| var ChatLink = "" |  | ||||||
| var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens |  | ||||||
| var DisplayInCurrencyEnabled = true |  | ||||||
| var DisplayTokenStatEnabled = true |  | ||||||
|  |  | ||||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions |  | ||||||
|  |  | ||||||
| var SessionSecret = uuid.New().String() |  | ||||||
|  |  | ||||||
| var OptionMap map[string]string |  | ||||||
| var OptionMapRWMutex sync.RWMutex |  | ||||||
|  |  | ||||||
| var ItemsPerPage = 10 |  | ||||||
| var MaxRecentItems = 100 |  | ||||||
|  |  | ||||||
| var PasswordLoginEnabled = true |  | ||||||
| var PasswordRegisterEnabled = true |  | ||||||
| var EmailVerificationEnabled = false |  | ||||||
| var GitHubOAuthEnabled = false |  | ||||||
| var WeChatAuthEnabled = false |  | ||||||
| var TurnstileCheckEnabled = false |  | ||||||
| var RegisterEnabled = true |  | ||||||
|  |  | ||||||
| var EmailDomainRestrictionEnabled = false |  | ||||||
| var EmailDomainWhitelist = []string{ |  | ||||||
| 	"gmail.com", |  | ||||||
| 	"163.com", |  | ||||||
| 	"126.com", |  | ||||||
| 	"qq.com", |  | ||||||
| 	"outlook.com", |  | ||||||
| 	"hotmail.com", |  | ||||||
| 	"icloud.com", |  | ||||||
| 	"yahoo.com", |  | ||||||
| 	"foxmail.com", |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" |  | ||||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" |  | ||||||
|  |  | ||||||
| var LogConsumeEnabled = true |  | ||||||
|  |  | ||||||
| var SMTPServer = "" |  | ||||||
| var SMTPPort = 587 |  | ||||||
| var SMTPAccount = "" |  | ||||||
| var SMTPFrom = "" |  | ||||||
| var SMTPToken = "" |  | ||||||
|  |  | ||||||
| var GitHubClientId = "" |  | ||||||
| var GitHubClientSecret = "" |  | ||||||
|  |  | ||||||
| var WeChatServerAddress = "" |  | ||||||
| var WeChatServerToken = "" |  | ||||||
| var WeChatAccountQRCodeImageURL = "" |  | ||||||
|  |  | ||||||
| var TurnstileSiteKey = "" |  | ||||||
| var TurnstileSecretKey = "" |  | ||||||
|  |  | ||||||
| var QuotaForNewUser = 0 |  | ||||||
| var QuotaForInviter = 0 |  | ||||||
| var QuotaForInvitee = 0 |  | ||||||
| var ChannelDisableThreshold = 5.0 |  | ||||||
| var AutomaticDisableChannelEnabled = false |  | ||||||
| var AutomaticEnableChannelEnabled = false |  | ||||||
| var QuotaRemindThreshold = 1000 |  | ||||||
| var PreConsumedQuota = 500 |  | ||||||
| var ApproximateTokenEnabled = false |  | ||||||
| var RetryTimes = 0 |  | ||||||
|  |  | ||||||
| var RootUserEmail = "" |  | ||||||
|  |  | ||||||
| var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" |  | ||||||
|  |  | ||||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) |  | ||||||
| var RequestInterval = time.Duration(requestInterval) * time.Second |  | ||||||
|  |  | ||||||
| var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second |  | ||||||
|  |  | ||||||
| var BatchUpdateEnabled = false |  | ||||||
| var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) |  | ||||||
|  |  | ||||||
| var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second |  | ||||||
|  |  | ||||||
| var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") |  | ||||||
|  |  | ||||||
| var Theme = helper.GetOrDefaultEnvString("THEME", "default") |  | ||||||
| var ValidThemes = map[string]bool{ |  | ||||||
| 	"default": true, |  | ||||||
| 	"berry":   true, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // All duration's unit is seconds |  | ||||||
| // Shouldn't larger then RateLimitKeyExpirationDuration |  | ||||||
| var ( |  | ||||||
| 	GlobalApiRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) |  | ||||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 |  | ||||||
|  |  | ||||||
| 	GlobalWebRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) |  | ||||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 |  | ||||||
|  |  | ||||||
| 	UploadRateLimitNum            = 10 |  | ||||||
| 	UploadRateLimitDuration int64 = 60 |  | ||||||
|  |  | ||||||
| 	DownloadRateLimitNum            = 10 |  | ||||||
| 	DownloadRateLimitDuration int64 = 60 |  | ||||||
|  |  | ||||||
| 	CriticalRateLimitNum            = 20 |  | ||||||
| 	CriticalRateLimitDuration int64 = 20 * 60 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute |  | ||||||
| @@ -1,9 +1,106 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import "time" | import ( | ||||||
|  | 	"os" | ||||||
|  | 	"strconv" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | ) | ||||||
|  |  | ||||||
| var StartTime = time.Now().Unix() // unit: second | var StartTime = time.Now().Unix() // unit: second | ||||||
| var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change | var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change | ||||||
|  | var SystemName = "One API" | ||||||
|  | var ServerAddress = "http://localhost:3000" | ||||||
|  | var Footer = "" | ||||||
|  | var Logo = "" | ||||||
|  | var TopUpLink = "" | ||||||
|  | var ChatLink = "" | ||||||
|  | var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | ||||||
|  | var DisplayInCurrencyEnabled = true | ||||||
|  | var DisplayTokenStatEnabled = true | ||||||
|  |  | ||||||
|  | // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||||
|  |  | ||||||
|  | var SessionSecret = uuid.New().String() | ||||||
|  |  | ||||||
|  | var OptionMap map[string]string | ||||||
|  | var OptionMapRWMutex sync.RWMutex | ||||||
|  |  | ||||||
|  | var ItemsPerPage = 10 | ||||||
|  | var MaxRecentItems = 100 | ||||||
|  |  | ||||||
|  | var PasswordLoginEnabled = true | ||||||
|  | var PasswordRegisterEnabled = true | ||||||
|  | var EmailVerificationEnabled = false | ||||||
|  | var GitHubOAuthEnabled = false | ||||||
|  | var WeChatAuthEnabled = false | ||||||
|  | var TurnstileCheckEnabled = false | ||||||
|  | var RegisterEnabled = true | ||||||
|  |  | ||||||
|  | var EmailDomainRestrictionEnabled = false | ||||||
|  | var EmailDomainWhitelist = []string{ | ||||||
|  | 	"gmail.com", | ||||||
|  | 	"163.com", | ||||||
|  | 	"126.com", | ||||||
|  | 	"qq.com", | ||||||
|  | 	"outlook.com", | ||||||
|  | 	"hotmail.com", | ||||||
|  | 	"icloud.com", | ||||||
|  | 	"yahoo.com", | ||||||
|  | 	"foxmail.com", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||||
|  | var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||||
|  |  | ||||||
|  | var LogConsumeEnabled = true | ||||||
|  |  | ||||||
|  | var SMTPServer = "" | ||||||
|  | var SMTPPort = 587 | ||||||
|  | var SMTPAccount = "" | ||||||
|  | var SMTPFrom = "" | ||||||
|  | var SMTPToken = "" | ||||||
|  |  | ||||||
|  | var GitHubClientId = "" | ||||||
|  | var GitHubClientSecret = "" | ||||||
|  |  | ||||||
|  | var WeChatServerAddress = "" | ||||||
|  | var WeChatServerToken = "" | ||||||
|  | var WeChatAccountQRCodeImageURL = "" | ||||||
|  |  | ||||||
|  | var TurnstileSiteKey = "" | ||||||
|  | var TurnstileSecretKey = "" | ||||||
|  |  | ||||||
|  | var QuotaForNewUser = 0 | ||||||
|  | var QuotaForInviter = 0 | ||||||
|  | var QuotaForInvitee = 0 | ||||||
|  | var ChannelDisableThreshold = 5.0 | ||||||
|  | var AutomaticDisableChannelEnabled = false | ||||||
|  | var AutomaticEnableChannelEnabled = false | ||||||
|  | var QuotaRemindThreshold = 1000 | ||||||
|  | var PreConsumedQuota = 500 | ||||||
|  | var ApproximateTokenEnabled = false | ||||||
|  | var RetryTimes = 0 | ||||||
|  |  | ||||||
|  | var RootUserEmail = "" | ||||||
|  |  | ||||||
|  | var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||||
|  |  | ||||||
|  | var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||||
|  | var RequestInterval = time.Duration(requestInterval) * time.Second | ||||||
|  |  | ||||||
|  | var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||||
|  |  | ||||||
|  | var BatchUpdateEnabled = false | ||||||
|  | var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||||
|  |  | ||||||
|  | var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RequestIdKey = "X-Oneapi-Request-Id" | ||||||
|  | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	RoleGuestUser  = 0 | 	RoleGuestUser  = 0 | ||||||
| @@ -12,6 +109,34 @@ const ( | |||||||
| 	RoleRootUser   = 100 | 	RoleRootUser   = 100 | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	FileUploadPermission    = RoleGuestUser | ||||||
|  | 	FileDownloadPermission  = RoleGuestUser | ||||||
|  | 	ImageUploadPermission   = RoleGuestUser | ||||||
|  | 	ImageDownloadPermission = RoleGuestUser | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // All duration's unit is seconds | ||||||
|  | // Shouldn't larger then RateLimitKeyExpirationDuration | ||||||
|  | var ( | ||||||
|  | 	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) | ||||||
|  | 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
|  | 	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 100) | ||||||
|  | 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
|  | 	UploadRateLimitNum            = 10 | ||||||
|  | 	UploadRateLimitDuration int64 = 60 | ||||||
|  |  | ||||||
|  | 	DownloadRateLimitNum            = 10 | ||||||
|  | 	DownloadRateLimitDuration int64 = 60 | ||||||
|  |  | ||||||
|  | 	CriticalRateLimitNum            = 20 | ||||||
|  | 	CriticalRateLimitDuration int64 = 20 * 60 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||||
| 	UserStatusDisabled = 2 // also don't use 0 | 	UserStatusDisabled = 2 // also don't use 0 | ||||||
| @@ -38,77 +163,74 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ChannelTypeUnknown = iota | 	ChannelTypeUnknown        = 0 | ||||||
| 	ChannelTypeOpenAI | 	ChannelTypeOpenAI         = 1 | ||||||
| 	ChannelTypeAPI2D | 	ChannelTypeAPI2D          = 2 | ||||||
| 	ChannelTypeAzure | 	ChannelTypeAzure          = 3 | ||||||
| 	ChannelTypeCloseAI | 	ChannelTypeCloseAI        = 4 | ||||||
| 	ChannelTypeOpenAISB | 	ChannelTypeOpenAISB       = 5 | ||||||
| 	ChannelTypeOpenAIMax | 	ChannelTypeOpenAIMax      = 6 | ||||||
| 	ChannelTypeOhMyGPT | 	ChannelTypeOhMyGPT        = 7 | ||||||
| 	ChannelTypeCustom | 	ChannelTypeCustom         = 8 | ||||||
| 	ChannelTypeAILS | 	ChannelTypeAILS           = 9 | ||||||
| 	ChannelTypeAIProxy | 	ChannelTypeAIProxy        = 10 | ||||||
| 	ChannelTypePaLM | 	ChannelTypePaLM           = 11 | ||||||
| 	ChannelTypeAPI2GPT | 	ChannelTypeAPI2GPT        = 12 | ||||||
| 	ChannelTypeAIGC2D | 	ChannelTypeAIGC2D         = 13 | ||||||
| 	ChannelTypeAnthropic | 	ChannelTypeAnthropic      = 14 | ||||||
| 	ChannelTypeBaidu | 	ChannelTypeBaidu          = 15 | ||||||
| 	ChannelTypeZhipu | 	ChannelTypeZhipu          = 16 | ||||||
| 	ChannelTypeAli | 	ChannelTypeAli            = 17 | ||||||
| 	ChannelTypeXunfei | 	ChannelTypeXunfei         = 18 | ||||||
| 	ChannelType360 | 	ChannelType360            = 19 | ||||||
| 	ChannelTypeOpenRouter | 	ChannelTypeOpenRouter     = 20 | ||||||
| 	ChannelTypeAIProxyLibrary | 	ChannelTypeAIProxyLibrary = 21 | ||||||
| 	ChannelTypeFastGPT | 	ChannelTypeFastGPT        = 22 | ||||||
| 	ChannelTypeTencent | 	ChannelTypeTencent        = 23 | ||||||
| 	ChannelTypeGemini | 	ChannelTypeAzureSpeech    = 24 | ||||||
| 	ChannelTypeMoonshot | 	ChannelTypeGemini         = 25 | ||||||
| 	ChannelTypeBaichuan |  | ||||||
| 	ChannelTypeMinimax |  | ||||||
| 	ChannelTypeMistral |  | ||||||
| 	ChannelTypeGroq |  | ||||||
|  |  | ||||||
| 	ChannelTypeDummy |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| 	"",                              // 0 | 	"",                                  // 0 | ||||||
| 	"https://api.openai.com",        // 1 | 	"https://api.openai.com",            // 1 | ||||||
| 	"https://oa.api2d.net",          // 2 | 	"https://oa.api2d.net",              // 2 | ||||||
| 	"",                              // 3 | 	"",                                  // 3 | ||||||
| 	"https://api.closeai-proxy.xyz", // 4 | 	"https://api.closeai-proxy.xyz",     // 4 | ||||||
| 	"https://api.openai-sb.com",     // 5 | 	"https://api.openai-sb.com",         // 5 | ||||||
| 	"https://api.openaimax.com",     // 6 | 	"https://api.openaimax.com",         // 6 | ||||||
| 	"https://api.ohmygpt.com",       // 7 | 	"https://api.ohmygpt.com",           // 7 | ||||||
| 	"",                              // 8 | 	"",                                  // 8 | ||||||
| 	"https://api.caipacity.com",     // 9 | 	"https://api.caipacity.com",         // 9 | ||||||
| 	"https://api.aiproxy.io",        // 10 | 	"https://api.aiproxy.io",            // 10 | ||||||
| 	"https://generativelanguage.googleapis.com", // 11 | 	"",                                  // 11 | ||||||
| 	"https://api.api2gpt.com",                   // 12 | 	"https://api.api2gpt.com",           // 12 | ||||||
| 	"https://api.aigc2d.com",                    // 13 | 	"https://api.aigc2d.com",            // 13 | ||||||
| 	"https://api.anthropic.com",                 // 14 | 	"https://api.anthropic.com",         // 14 | ||||||
| 	"https://aip.baidubce.com",                  // 15 | 	"https://aip.baidubce.com",          // 15 | ||||||
| 	"https://open.bigmodel.cn",                  // 16 | 	"https://open.bigmodel.cn",          // 16 | ||||||
| 	"https://dashscope.aliyuncs.com",            // 17 | 	"https://dashscope.aliyuncs.com",    // 17 | ||||||
| 	"",                                          // 18 | 	"",                                  // 18 | ||||||
| 	"https://ai.360.cn",                         // 19 | 	"https://ai.360.cn",                 // 19 | ||||||
| 	"https://openrouter.ai/api",                 // 20 | 	"https://openrouter.ai/api",         // 20 | ||||||
| 	"https://api.aiproxy.io",                    // 21 | 	"https://api.aiproxy.io",            // 21 | ||||||
| 	"https://fastgpt.run/api/openapi",           // 22 | 	"https://fastgpt.run/api/openapi",   // 22 | ||||||
| 	"https://hunyuan.cloud.tencent.com",         // 23 | 	"https://hunyuan.cloud.tencent.com", //23 | ||||||
| 	"https://generativelanguage.googleapis.com", // 24 | 	"",                                  //24 | ||||||
| 	"https://api.moonshot.cn",                   // 25 | 	"",                                  //25 | ||||||
| 	"https://api.baichuan-ai.com",               // 26 |  | ||||||
| 	"https://api.minimax.chat",                  // 27 |  | ||||||
| 	"https://api.mistral.ai",                    // 28 |  | ||||||
| 	"https://api.groq.com/openai",               // 29 |  | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ConfigKeyPrefix = "cfg_" | 	RelayModeUnknown = iota | ||||||
|  | 	RelayModeChatCompletions | ||||||
| 	ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version" | 	RelayModeCompletions | ||||||
| 	ConfigKeyLibraryID  = ConfigKeyPrefix + "library_id" | 	RelayModeEmbeddings | ||||||
| 	ConfigKeyPlugin     = ConfigKeyPrefix + "plugin" | 	RelayModeModerations | ||||||
|  | 	RelayModeImagesGenerations | ||||||
|  | 	RelayModeImagesEdits | ||||||
|  | 	RelayModeImagesVariations | ||||||
|  | 	RelayModeEdits | ||||||
|  | 	RelayModeAudioSpeech | ||||||
|  | 	RelayModeAudioTranscription | ||||||
|  | 	RelayModeAudioTranslation | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -1,9 +1,7 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import "github.com/songquanpeng/one-api/common/helper" |  | ||||||
|  |  | ||||||
| var UsingSQLite = false | var UsingSQLite = false | ||||||
| var UsingPostgreSQL = false | var UsingPostgreSQL = false | ||||||
|  |  | ||||||
| var SQLitePath = "one-api.db" | var SQLitePath = "one-api.db" | ||||||
| var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) | var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) | ||||||
|   | |||||||
| @@ -5,20 +5,19 @@ import ( | |||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"net/smtp" | 	"net/smtp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func SendEmail(subject string, receiver string, content string) error { | func SendEmail(subject string, receiver string, content string) error { | ||||||
| 	if config.SMTPFrom == "" { // for compatibility | 	if SMTPFrom == "" { // for compatibility | ||||||
| 		config.SMTPFrom = config.SMTPAccount | 		SMTPFrom = SMTPAccount | ||||||
| 	} | 	} | ||||||
| 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) | 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) | ||||||
|  |  | ||||||
| 	// Extract domain from SMTPFrom | 	// Extract domain from SMTPFrom | ||||||
| 	parts := strings.Split(config.SMTPFrom, "@") | 	parts := strings.Split(SMTPFrom, "@") | ||||||
| 	var domain string | 	var domain string | ||||||
| 	if len(parts) > 1 { | 	if len(parts) > 1 { | ||||||
| 		domain = parts[1] | 		domain = parts[1] | ||||||
| @@ -37,21 +36,21 @@ func SendEmail(subject string, receiver string, content string) error { | |||||||
| 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 | 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 | ||||||
| 		"Date: %s\r\n"+ | 		"Date: %s\r\n"+ | ||||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||||
| 		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | 		receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||||
| 	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) | 	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) | ||||||
| 	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) | 	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) | ||||||
| 	to := strings.Split(receiver, ";") | 	to := strings.Split(receiver, ";") | ||||||
|  |  | ||||||
| 	if config.SMTPPort == 465 { | 	if SMTPPort == 465 { | ||||||
| 		tlsConfig := &tls.Config{ | 		tlsConfig := &tls.Config{ | ||||||
| 			InsecureSkipVerify: true, | 			InsecureSkipVerify: true, | ||||||
| 			ServerName:         config.SMTPServer, | 			ServerName:         SMTPServer, | ||||||
| 		} | 		} | ||||||
| 		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) | 		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		client, err := smtp.NewClient(conn, config.SMTPServer) | 		client, err := smtp.NewClient(conn, SMTPServer) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @@ -59,7 +58,7 @@ func SendEmail(subject string, receiver string, content string) error { | |||||||
| 		if err = client.Auth(auth); err != nil { | 		if err = client.Auth(auth); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		if err = client.Mail(config.SMTPFrom); err != nil { | 		if err = client.Mail(SMTPFrom); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		receiverEmails := strings.Split(receiver, ";") | 		receiverEmails := strings.Split(receiver, ";") | ||||||
| @@ -81,7 +80,7 @@ func SendEmail(subject string, receiver string, content string) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) | 		err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) | ||||||
| 	} | 	} | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|   | |||||||
| @@ -15,7 +15,10 @@ type embedFileSystem struct { | |||||||
|  |  | ||||||
| func (e embedFileSystem) Exists(prefix string, path string) bool { | func (e embedFileSystem) Exists(prefix string, path string) bool { | ||||||
| 	_, err := e.Open(path) | 	_, err := e.Open(path) | ||||||
| 	return err == nil | 	if err != nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
| } | } | ||||||
|  |  | ||||||
| func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { | func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { | ||||||
|   | |||||||
							
								
								
									
										71
									
								
								common/form_builder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								common/form_builder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | |||||||
|  | package common | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"mime/multipart" | ||||||
|  | 	"path" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type FormBuilder interface { | ||||||
|  | 	CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error | ||||||
|  | 	CreateFormFileReader(fieldname string, r io.Reader, filename string) error | ||||||
|  | 	WriteField(fieldname, value string) error | ||||||
|  | 	Close() error | ||||||
|  | 	FormDataContentType() string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type DefaultFormBuilder struct { | ||||||
|  | 	writer *multipart.Writer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewFormBuilder(body io.Writer) *DefaultFormBuilder { | ||||||
|  | 	return &DefaultFormBuilder{ | ||||||
|  | 		writer: multipart.NewWriter(body), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error { | ||||||
|  | 	file, err := fileHeader.Open() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defer file.Close() | ||||||
|  |  | ||||||
|  | 	return fb.createFormFile(fieldname, file, fileHeader.Filename) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { | ||||||
|  | 	return fb.createFormFile(fieldname, r, path.Base(filename)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { | ||||||
|  | 	if filename == "" { | ||||||
|  | 		return fmt.Errorf("filename cannot be empty") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = io.Copy(fieldWriter, r) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { | ||||||
|  | 	return fb.writer.WriteField(fieldname, value) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (fb *DefaultFormBuilder) Close() error { | ||||||
|  | 	return fb.writer.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (fb *DefaultFormBuilder) FormDataContentType() string { | ||||||
|  | 	return fb.writer.FormDataContentType() | ||||||
|  | } | ||||||
| @@ -2,52 +2,60 @@ package common | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/json" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"strings" | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/go-playground/validator/v10" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const KeyRequestBody = "key_request_body" | func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||||
|  |  | ||||||
| func GetRequestBody(c *gin.Context) ([]byte, error) { |  | ||||||
| 	requestBody, _ := c.Get(KeyRequestBody) |  | ||||||
| 	if requestBody != nil { |  | ||||||
| 		return requestBody.([]byte), nil |  | ||||||
| 	} |  | ||||||
| 	requestBody, err := io.ReadAll(c.Request.Body) | 	requestBody, err := io.ReadAll(c.Request.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return err | ||||||
|  | 	} | ||||||
|  | 	err = c.Request.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
|  | 	err = c.ShouldBind(v) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if errs, ok := err.(validator.ValidationErrors); ok { | ||||||
|  | 			// 返回第一个错误字段的名称 | ||||||
|  | 			return fmt.Errorf("field %s is required", errs[0].Field()) | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
| 	} | 	} | ||||||
| 	_ = c.Request.Body.Close() |  | ||||||
| 	c.Set(KeyRequestBody, requestBody) |  | ||||||
| 	return requestBody.([]byte), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { |  | ||||||
| 	requestBody, err := GetRequestBody(c) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	contentType := c.Request.Header.Get("Content-Type") |  | ||||||
| 	if strings.HasPrefix(contentType, "application/json") { |  | ||||||
| 		err = json.Unmarshal(requestBody, &v) |  | ||||||
| 	} else { |  | ||||||
| 		// skip for now |  | ||||||
| 		// TODO: someday non json request have variant model, we will need to implementation this |  | ||||||
| 	} |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	// Reset request body |  | ||||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func SetEventStreamHeaders(c *gin.Context) { | func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWithStatusCode { | ||||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | 	return StringErrorWrapper(err.Error(), code, statusCode) | ||||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | } | ||||||
| 	c.Writer.Header().Set("Connection", "keep-alive") |  | ||||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode { | ||||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | 	openAIError := types.OpenAIError{ | ||||||
|  | 		Message: err, | ||||||
|  | 		Type:    "one_api_error", | ||||||
|  | 		Code:    code, | ||||||
|  | 	} | ||||||
|  | 	return &types.OpenAIErrorWithStatusCode{ | ||||||
|  | 		OpenAIError: openAIError, | ||||||
|  | 		StatusCode:  statusCode, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func AbortWithMessage(c *gin.Context, statusCode int, message string) { | ||||||
|  | 	c.JSON(statusCode, gin.H{ | ||||||
|  | 		"error": gin.H{ | ||||||
|  | 			"message": message, | ||||||
|  | 			"type":    "one_api_error", | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	c.Abort() | ||||||
|  | 	LogError(c.Request.Context(), message) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,9 +1,6 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import ( | import "encoding/json" | ||||||
| 	"encoding/json" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var GroupRatio = map[string]float64{ | var GroupRatio = map[string]float64{ | ||||||
| 	"default": 1, | 	"default": 1, | ||||||
| @@ -14,7 +11,7 @@ var GroupRatio = map[string]float64{ | |||||||
| func GroupRatio2JSONString() string { | func GroupRatio2JSONString() string { | ||||||
| 	jsonBytes, err := json.Marshal(GroupRatio) | 	jsonBytes, err := json.Marshal(GroupRatio) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("error marshalling model ratio: " + err.Error()) | 		SysError("error marshalling model ratio: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return string(jsonBytes) | 	return string(jsonBytes) | ||||||
| } | } | ||||||
| @@ -27,7 +24,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error { | |||||||
| func GetGroupRatio(name string) float64 { | func GetGroupRatio(name string) float64 { | ||||||
| 	ratio, ok := GroupRatio[name] | 	ratio, ok := GroupRatio[name] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		logger.SysError("group ratio not found: " + name) | 		SysError("group ratio not found: " + name) | ||||||
| 		return 1 | 		return 1 | ||||||
| 	} | 	} | ||||||
| 	return ratio | 	return ratio | ||||||
|   | |||||||
| @@ -1,234 +0,0 @@ | |||||||
| package helper |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/google/uuid" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"html/template" |  | ||||||
| 	"log" |  | ||||||
| 	"math/rand" |  | ||||||
| 	"net" |  | ||||||
| 	"os" |  | ||||||
| 	"os/exec" |  | ||||||
| 	"runtime" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func OpenBrowser(url string) { |  | ||||||
| 	var err error |  | ||||||
|  |  | ||||||
| 	switch runtime.GOOS { |  | ||||||
| 	case "linux": |  | ||||||
| 		err = exec.Command("xdg-open", url).Start() |  | ||||||
| 	case "windows": |  | ||||||
| 		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() |  | ||||||
| 	case "darwin": |  | ||||||
| 		err = exec.Command("open", url).Start() |  | ||||||
| 	} |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Println(err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetIp() (ip string) { |  | ||||||
| 	ips, err := net.InterfaceAddrs() |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Println(err) |  | ||||||
| 		return ip |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, a := range ips { |  | ||||||
| 		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { |  | ||||||
| 			if ipNet.IP.To4() != nil { |  | ||||||
| 				ip = ipNet.IP.String() |  | ||||||
| 				if strings.HasPrefix(ip, "10") { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				if strings.HasPrefix(ip, "172") { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				if strings.HasPrefix(ip, "192.168") { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				ip = "" |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var sizeKB = 1024 |  | ||||||
| var sizeMB = sizeKB * 1024 |  | ||||||
| var sizeGB = sizeMB * 1024 |  | ||||||
|  |  | ||||||
| func Bytes2Size(num int64) string { |  | ||||||
| 	numStr := "" |  | ||||||
| 	unit := "B" |  | ||||||
| 	if num/int64(sizeGB) > 1 { |  | ||||||
| 		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) |  | ||||||
| 		unit = "GB" |  | ||||||
| 	} else if num/int64(sizeMB) > 1 { |  | ||||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) |  | ||||||
| 		unit = "MB" |  | ||||||
| 	} else if num/int64(sizeKB) > 1 { |  | ||||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) |  | ||||||
| 		unit = "KB" |  | ||||||
| 	} else { |  | ||||||
| 		numStr = fmt.Sprintf("%d", num) |  | ||||||
| 	} |  | ||||||
| 	return numStr + " " + unit |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Seconds2Time(num int) (time string) { |  | ||||||
| 	if num/31104000 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/31104000) + " 年 " |  | ||||||
| 		num %= 31104000 |  | ||||||
| 	} |  | ||||||
| 	if num/2592000 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/2592000) + " 个月 " |  | ||||||
| 		num %= 2592000 |  | ||||||
| 	} |  | ||||||
| 	if num/86400 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/86400) + " 天 " |  | ||||||
| 		num %= 86400 |  | ||||||
| 	} |  | ||||||
| 	if num/3600 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/3600) + " 小时 " |  | ||||||
| 		num %= 3600 |  | ||||||
| 	} |  | ||||||
| 	if num/60 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/60) + " 分钟 " |  | ||||||
| 		num %= 60 |  | ||||||
| 	} |  | ||||||
| 	time += strconv.Itoa(num) + " 秒" |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Interface2String(inter interface{}) string { |  | ||||||
| 	switch inter := inter.(type) { |  | ||||||
| 	case string: |  | ||||||
| 		return inter |  | ||||||
| 	case int: |  | ||||||
| 		return fmt.Sprintf("%d", inter) |  | ||||||
| 	case float64: |  | ||||||
| 		return fmt.Sprintf("%f", inter) |  | ||||||
| 	} |  | ||||||
| 	return "Not Implemented" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func UnescapeHTML(x string) interface{} { |  | ||||||
| 	return template.HTML(x) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func IntMax(a int, b int) int { |  | ||||||
| 	if a >= b { |  | ||||||
| 		return a |  | ||||||
| 	} else { |  | ||||||
| 		return b |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetUUID() string { |  | ||||||
| 	code := uuid.New().String() |  | ||||||
| 	code = strings.Replace(code, "-", "", -1) |  | ||||||
| 	return code |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" |  | ||||||
| const keyNumbers = "0123456789" |  | ||||||
|  |  | ||||||
| func init() { |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GenerateKey() string { |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) |  | ||||||
| 	key := make([]byte, 48) |  | ||||||
| 	for i := 0; i < 16; i++ { |  | ||||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] |  | ||||||
| 	} |  | ||||||
| 	uuid_ := GetUUID() |  | ||||||
| 	for i := 0; i < 32; i++ { |  | ||||||
| 		c := uuid_[i] |  | ||||||
| 		if i%2 == 0 && c >= 'a' && c <= 'z' { |  | ||||||
| 			c = c - 'a' + 'A' |  | ||||||
| 		} |  | ||||||
| 		key[i+16] = c |  | ||||||
| 	} |  | ||||||
| 	return string(key) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetRandomString(length int) string { |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) |  | ||||||
| 	key := make([]byte, length) |  | ||||||
| 	for i := 0; i < length; i++ { |  | ||||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] |  | ||||||
| 	} |  | ||||||
| 	return string(key) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetRandomNumberString(length int) string { |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) |  | ||||||
| 	key := make([]byte, length) |  | ||||||
| 	for i := 0; i < length; i++ { |  | ||||||
| 		key[i] = keyNumbers[rand.Intn(len(keyNumbers))] |  | ||||||
| 	} |  | ||||||
| 	return string(key) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetTimestamp() int64 { |  | ||||||
| 	return time.Now().Unix() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetTimeString() string { |  | ||||||
| 	now := time.Now() |  | ||||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Max(a int, b int) int { |  | ||||||
| 	if a >= b { |  | ||||||
| 		return a |  | ||||||
| 	} else { |  | ||||||
| 		return b |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetOrDefaultEnvInt(env string, defaultValue int) int { |  | ||||||
| 	if env == "" || os.Getenv(env) == "" { |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	num, err := strconv.Atoi(os.Getenv(env)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	return num |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetOrDefaultEnvString(env string, defaultValue string) string { |  | ||||||
| 	if env == "" || os.Getenv(env) == "" { |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	return os.Getenv(env) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func AssignOrDefault(value string, defaultValue string) string { |  | ||||||
| 	if len(value) != 0 { |  | ||||||
| 		return value |  | ||||||
| 	} |  | ||||||
| 	return defaultValue |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func MessageWithRequestId(message string, id string) string { |  | ||||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func String2Int(str string) int { |  | ||||||
| 	num, err := strconv.Atoi(str) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0 |  | ||||||
| 	} |  | ||||||
| 	return num |  | ||||||
| } |  | ||||||
| @@ -3,6 +3,7 @@ package image | |||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
|  | 	"errors" | ||||||
| 	"image" | 	"image" | ||||||
| 	_ "image/gif" | 	_ "image/gif" | ||||||
| 	_ "image/jpeg" | 	_ "image/jpeg" | ||||||
| @@ -15,9 +16,6 @@ import ( | |||||||
| 	_ "golang.org/x/image/webp" | 	_ "golang.org/x/image/webp" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Regex to match data URL pattern |  | ||||||
| var	dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) |  | ||||||
|  |  | ||||||
| func IsImageUrl(url string) (bool, error) { | func IsImageUrl(url string) (bool, error) { | ||||||
| 	resp, err := http.Head(url) | 	resp, err := http.Head(url) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -47,17 +45,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetImageFromUrl(url string) (mimeType string, data string, err error) { | func GetImageFromUrl(url string) (mimeType string, data string, err error) { | ||||||
| 	// Check if the URL is a data URL |  | ||||||
| 	matches := dataURLPattern.FindStringSubmatch(url) | 	if strings.HasPrefix(url, "data:image/") { | ||||||
| 	if len(matches) == 3 { | 		dataURLPattern := regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) | ||||||
| 		// URL is a data URL |  | ||||||
| 		mimeType = "image/" + matches[1] | 		matches := dataURLPattern.FindStringSubmatch(url) | ||||||
| 		data = matches[2] | 		if len(matches) == 3 && matches[2] != "" { | ||||||
|  | 			mimeType = "image/" + matches[1] | ||||||
|  | 			data = matches[2] | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		err = errors.New("image base64 decode failed") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	isImage, err := IsImageUrl(url) | 	isImage, err := IsImageUrl(url) | ||||||
| 	if !isImage { | 	if !isImage { | ||||||
|  | 		if err == nil { | ||||||
|  | 			err = errors.New("invalid image link") | ||||||
|  | 		} | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	resp, err := http.Get(url) | 	resp, err := http.Get(url) | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	img "github.com/songquanpeng/one-api/common/image" | 	img "one-api/common/image" | ||||||
|  |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	_ "golang.org/x/image/webp" | 	_ "golang.org/x/image/webp" | ||||||
| @@ -169,3 +169,34 @@ func TestGetImageSizeFromBase64(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestGetImageFromUrl(t *testing.T) { | ||||||
|  | 	for i, c := range cases { | ||||||
|  | 		t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { | ||||||
|  | 			resp, err := http.Get(c.url) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 			defer resp.Body.Close() | ||||||
|  | 			data, err := io.ReadAll(resp.Body) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 			encoded := base64.StdEncoding.EncodeToString(data) | ||||||
|  |  | ||||||
|  | 			mimeType, base64Data, err := img.GetImageFromUrl(c.url) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 			assert.Equal(t, encoded, base64Data) | ||||||
|  | 			assert.Equal(t, "image/"+c.format, mimeType) | ||||||
|  |  | ||||||
|  | 			encodedBase64 := "data:image/" + c.format + ";base64," + encoded | ||||||
|  | 			mimeType, base64Data, err = img.GetImageFromUrl(encodedBase64) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 			assert.Equal(t, encoded, base64Data) | ||||||
|  | 			assert.Equal(t, "image/"+c.format, mimeType) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	url := "https://raw.githubusercontent.com/songquanpeng/one-api/main/README.md" | ||||||
|  | 	_, _, err := img.GetImageFromUrl(url) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 	encodedBase64 := "data:image/text;base64," | ||||||
|  | 	_, _, err = img.GetImageFromUrl(encodedBase64) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -3,11 +3,11 @@ package common | |||||||
| import ( | import ( | ||||||
| 	"flag" | 	"flag" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  |  | ||||||
|  | 	"github.com/joho/godotenv" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| @@ -25,6 +25,11 @@ func printHelp() { | |||||||
| } | } | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
|  | 	// 加载.env文件 | ||||||
|  | 	err := godotenv.Load() | ||||||
|  | 	if err != nil { | ||||||
|  | 		SysLog("failed to load .env file: " + err.Error()) | ||||||
|  | 	} | ||||||
| 	flag.Parse() | 	flag.Parse() | ||||||
|  |  | ||||||
| 	if *PrintVersion { | 	if *PrintVersion { | ||||||
| @@ -39,9 +44,9 @@ func init() { | |||||||
|  |  | ||||||
| 	if os.Getenv("SESSION_SECRET") != "" { | 	if os.Getenv("SESSION_SECRET") != "" { | ||||||
| 		if os.Getenv("SESSION_SECRET") == "random_string" { | 		if os.Getenv("SESSION_SECRET") == "random_string" { | ||||||
| 			logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") | 			SysError("SESSION_SECRET is set to an example value, please change it to a random string.") | ||||||
| 		} else { | 		} else { | ||||||
| 			config.SessionSecret = os.Getenv("SESSION_SECRET") | 			SessionSecret = os.Getenv("SESSION_SECRET") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("SQLITE_PATH") != "" { | 	if os.Getenv("SQLITE_PATH") != "" { | ||||||
| @@ -59,6 +64,5 @@ func init() { | |||||||
| 				log.Fatal(err) | 				log.Fatal(err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		logger.LogDir = *LogDir |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| package logger | package common | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| @@ -13,7 +13,6 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	loggerDEBUG = "DEBUG" |  | ||||||
| 	loggerINFO  = "INFO" | 	loggerINFO  = "INFO" | ||||||
| 	loggerWarn  = "WARN" | 	loggerWarn  = "WARN" | ||||||
| 	loggerError = "ERR" | 	loggerError = "ERR" | ||||||
| @@ -26,7 +25,7 @@ var setupLogLock sync.Mutex | |||||||
| var setupLogWorking bool | var setupLogWorking bool | ||||||
| 
 | 
 | ||||||
| func SetupLogger() { | func SetupLogger() { | ||||||
| 	if LogDir != "" { | 	if *LogDir != "" { | ||||||
| 		ok := setupLogLock.TryLock() | 		ok := setupLogLock.TryLock() | ||||||
| 		if !ok { | 		if !ok { | ||||||
| 			log.Println("setup log is already working") | 			log.Println("setup log is already working") | ||||||
| @@ -36,7 +35,7 @@ func SetupLogger() { | |||||||
| 			setupLogLock.Unlock() | 			setupLogLock.Unlock() | ||||||
| 			setupLogWorking = false | 			setupLogWorking = false | ||||||
| 		}() | 		}() | ||||||
| 		logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | 		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||||
| 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Fatal("failed to open log file") | 			log.Fatal("failed to open log file") | ||||||
| @@ -56,38 +55,18 @@ func SysError(s string) { | |||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Debug(ctx context.Context, msg string) { | func LogInfo(ctx context.Context, msg string) { | ||||||
| 	logHelper(ctx, loggerDEBUG, msg) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Info(ctx context.Context, msg string) { |  | ||||||
| 	logHelper(ctx, loggerINFO, msg) | 	logHelper(ctx, loggerINFO, msg) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Warn(ctx context.Context, msg string) { | func LogWarn(ctx context.Context, msg string) { | ||||||
| 	logHelper(ctx, loggerWarn, msg) | 	logHelper(ctx, loggerWarn, msg) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Error(ctx context.Context, msg string) { | func LogError(ctx context.Context, msg string) { | ||||||
| 	logHelper(ctx, loggerError, msg) | 	logHelper(ctx, loggerError, msg) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Debugf(ctx context.Context, format string, a ...any) { |  | ||||||
| 	Debug(ctx, fmt.Sprintf(format, a...)) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Infof(ctx context.Context, format string, a ...any) { |  | ||||||
| 	Info(ctx, fmt.Sprintf(format, a...)) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Warnf(ctx context.Context, format string, a ...any) { |  | ||||||
| 	Warn(ctx, fmt.Sprintf(format, a...)) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func Errorf(ctx context.Context, format string, a ...any) { |  | ||||||
| 	Error(ctx, fmt.Sprintf(format, a...)) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func logHelper(ctx context.Context, level string, msg string) { | func logHelper(ctx context.Context, level string, msg string) { | ||||||
| 	writer := gin.DefaultErrorWriter | 	writer := gin.DefaultErrorWriter | ||||||
| 	if level == loggerINFO { | 	if level == loggerINFO { | ||||||
| @@ -111,3 +90,11 @@ func FatalLog(v ...any) { | |||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | ||||||
| 	os.Exit(1) | 	os.Exit(1) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func LogQuota(quota int) string { | ||||||
|  | 	if DisplayInCurrencyEnabled { | ||||||
|  | 		return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) | ||||||
|  | 	} else { | ||||||
|  | 		return fmt.Sprintf("%d 点额度", quota) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,7 +0,0 @@ | |||||||
| package logger |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	RequestIdKey = "X-Oneapi-Request-Id" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var LogDir string |  | ||||||
							
								
								
									
										15
									
								
								common/marshaller.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								common/marshaller.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | |||||||
|  | package common | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Marshaller interface { | ||||||
|  | 	Marshal(value any) ([]byte, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type JSONMarshaller struct{} | ||||||
|  |  | ||||||
|  | func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) { | ||||||
|  | 	return json.Marshal(value) | ||||||
|  | } | ||||||
| @@ -2,89 +2,89 @@ package common | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | var DalleSizeRatios = map[string]map[string]float64{ | ||||||
| 	USD2RMB = 7 | 	"dall-e-2": { | ||||||
| 	USD     = 500 // $0.002 = 1 -> $1 = 500 | 		"256x256":   1, | ||||||
| 	RMB     = USD / USD2RMB | 		"512x512":   1.125, | ||||||
| ) | 		"1024x1024": 1.25, | ||||||
|  | 	}, | ||||||
|  | 	"dall-e-3": { | ||||||
|  | 		"1024x1024": 1, | ||||||
|  | 		"1024x1792": 2, | ||||||
|  | 		"1792x1024": 2, | ||||||
|  | 	}, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DalleGenerationImageAmounts = map[string][2]int{ | ||||||
|  | 	"dall-e-2": {1, 10}, | ||||||
|  | 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DalleImagePromptLengthLimitations = map[string]int{ | ||||||
|  | 	"dall-e-2": 1000, | ||||||
|  | 	"dall-e-3": 4000, | ||||||
|  | } | ||||||
|  |  | ||||||
| // ModelRatio | // ModelRatio | ||||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||||
| // https://openai.com/pricing | // https://openai.com/pricing | ||||||
|  | // TODO: when a new api is enabled, check the pricing here | ||||||
| // 1 === $0.002 / 1K tokens | // 1 === $0.002 / 1K tokens | ||||||
| // 1 === ¥0.014 / 1k tokens | // 1 === ¥0.014 / 1k tokens | ||||||
| var ModelRatio = map[string]float64{ | var ModelRatio = map[string]float64{ | ||||||
| 	// https://openai.com/pricing | 	"gpt-4":                     15, | ||||||
| 	"gpt-4":                   15, | 	"gpt-4-0314":                15, | ||||||
| 	"gpt-4-0314":              15, | 	"gpt-4-0613":                15, | ||||||
| 	"gpt-4-0613":              15, | 	"gpt-4-32k":                 30, | ||||||
| 	"gpt-4-32k":               30, | 	"gpt-4-32k-0314":            30, | ||||||
| 	"gpt-4-32k-0314":          30, | 	"gpt-4-32k-0613":            30, | ||||||
| 	"gpt-4-32k-0613":          30, | 	"gpt-4-1106-preview":        5,    // $0.01 / 1K tokens | ||||||
| 	"gpt-4-1106-preview":      5,    // $0.01 / 1K tokens | 	"gpt-4-vision-preview":      5,    // $0.01 / 1K tokens | ||||||
| 	"gpt-4-0125-preview":      5,    // $0.01 / 1K tokens | 	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens | ||||||
| 	"gpt-4-turbo-preview":     5,    // $0.01 / 1K tokens | 	"gpt-3.5-turbo-0301":        0.75, | ||||||
| 	"gpt-4-vision-preview":    5,    // $0.01 / 1K tokens | 	"gpt-3.5-turbo-0613":        0.75, | ||||||
| 	"gpt-3.5-turbo":           0.75, // $0.0015 / 1K tokens | 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-0301":      0.75, | 	"gpt-3.5-turbo-16k-0613":    1.5, | ||||||
| 	"gpt-3.5-turbo-0613":      0.75, | 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens | 	"gpt-3.5-turbo-1106":        0.5,  // $0.001 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-16k-0613":  1.5, | 	"text-ada-001":              0.2, | ||||||
| 	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens | 	"text-babbage-001":          0.25, | ||||||
| 	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens | 	"text-curie-001":            1, | ||||||
| 	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens | 	"text-davinci-002":          10, | ||||||
| 	"davinci-002":             1,    // $0.002 / 1K tokens | 	"text-davinci-003":          10, | ||||||
| 	"babbage-002":             0.2,  // $0.0004 / 1K tokens | 	"text-davinci-edit-001":     10, | ||||||
| 	"text-ada-001":            0.2, | 	"code-davinci-edit-001":     10, | ||||||
| 	"text-babbage-001":        0.25, | 	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||||
| 	"text-curie-001":          1, | 	"tts-1":                     7.5, // $0.015 / 1K characters | ||||||
| 	"text-davinci-002":        10, | 	"tts-1-1106":                7.5, | ||||||
| 	"text-davinci-003":        10, | 	"tts-1-hd":                  15, // $0.030 / 1K characters | ||||||
| 	"text-davinci-edit-001":   10, | 	"tts-1-hd-1106":             15, | ||||||
| 	"code-davinci-edit-001":   10, | 	"davinci":                   10, | ||||||
| 	"whisper-1":               15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | 	"curie":                     10, | ||||||
| 	"tts-1":                   7.5, // $0.015 / 1K characters | 	"babbage":                   10, | ||||||
| 	"tts-1-1106":              7.5, | 	"ada":                       10, | ||||||
| 	"tts-1-hd":                15, // $0.030 / 1K characters | 	"text-embedding-ada-002":    0.05, | ||||||
| 	"tts-1-hd-1106":           15, | 	"text-search-ada-doc-001":   10, | ||||||
| 	"davinci":                 10, | 	"text-moderation-stable":    0.1, | ||||||
| 	"curie":                   10, | 	"text-moderation-latest":    0.1, | ||||||
| 	"babbage":                 10, | 	"dall-e-2":                  8,      // $0.016 - $0.020 / image | ||||||
| 	"ada":                     10, | 	"dall-e-3":                  20,     // $0.040 - $0.120 / image | ||||||
| 	"text-embedding-ada-002":  0.05, | 	"claude-instant-1":          0.815,  // $1.63 / 1M tokens | ||||||
| 	"text-embedding-3-small":  0.01, | 	"claude-2":                  5.51,   // $11.02 / 1M tokens | ||||||
| 	"text-embedding-3-large":  0.065, | 	"claude-2.0":                5.51,   // $11.02 / 1M tokens | ||||||
| 	"text-search-ada-doc-001": 10, | 	"claude-2.1":                5.51,   // $11.02 / 1M tokens | ||||||
| 	"text-moderation-stable":  0.1, | 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens | ||||||
| 	"text-moderation-latest":  0.1, | 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens | ||||||
| 	"dall-e-2":                8,  // $0.016 - $0.020 / image | 	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens | ||||||
| 	"dall-e-3":                20, // $0.040 - $0.120 / image | 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | ||||||
| 	// https://www.anthropic.com/api#pricing | 	"PaLM-2":                    1, | ||||||
| 	"claude-instant-1.2":       0.8 / 1000 * USD, | 	"gemini-pro":                1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||||
| 	"claude-2.0":               8.0 / 1000 * USD, | 	"gemini-pro-vision":         1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||||
| 	"claude-2.1":               8.0 / 1000 * USD, |  | ||||||
| 	"claude-3-haiku-20240229":  0.25 / 1000 * USD, |  | ||||||
| 	"claude-3-sonnet-20240229": 3.0 / 1000 * USD, |  | ||||||
| 	"claude-3-opus-20240229":   15.0 / 1000 * USD, |  | ||||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 |  | ||||||
| 	"ERNIE-Bot":         0.8572,     // ¥0.012 / 1k tokens |  | ||||||
| 	"ERNIE-Bot-turbo":   0.5715,     // ¥0.008 / 1k tokens |  | ||||||
| 	"ERNIE-Bot-4":       0.12 * RMB, // ¥0.12 / 1k tokens |  | ||||||
| 	"ERNIE-Bot-8k":      0.024 * RMB, |  | ||||||
| 	"Embedding-V1":      0.1429, // ¥0.002 / 1k tokens |  | ||||||
| 	"PaLM-2":            1, |  | ||||||
| 	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens |  | ||||||
| 	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens |  | ||||||
| 	// https://open.bigmodel.cn/pricing |  | ||||||
| 	"glm-4":                     0.1 * RMB, |  | ||||||
| 	"glm-4v":                    0.1 * RMB, |  | ||||||
| 	"glm-3-turbo":               0.005 * RMB, |  | ||||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||||
| @@ -93,65 +93,20 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens | 	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens | ||||||
| 	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens | 	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens | ||||||
| 	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens | 	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens | ||||||
|  | 	"qwen-vl-plus":              0.5715, // ¥0.008 / 1k tokens | ||||||
| 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | ||||||
| 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | ||||||
| 	"SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens |  | ||||||
| 	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens |  | ||||||
| 	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens |  | ||||||
| 	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens |  | ||||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||||
| 	"ChatStd":                   0.01 * RMB, |  | ||||||
| 	"ChatPro":                   0.1 * RMB, |  | ||||||
| 	// https://platform.moonshot.cn/pricing |  | ||||||
| 	"moonshot-v1-8k":   0.012 * RMB, |  | ||||||
| 	"moonshot-v1-32k":  0.024 * RMB, |  | ||||||
| 	"moonshot-v1-128k": 0.06 * RMB, |  | ||||||
| 	// https://platform.baichuan-ai.com/price |  | ||||||
| 	"Baichuan2-Turbo":      0.008 * RMB, |  | ||||||
| 	"Baichuan2-Turbo-192k": 0.016 * RMB, |  | ||||||
| 	"Baichuan2-53B":        0.02 * RMB, |  | ||||||
| 	// https://api.minimax.chat/document/price |  | ||||||
| 	"abab6-chat":    0.1 * RMB, |  | ||||||
| 	"abab5.5-chat":  0.015 * RMB, |  | ||||||
| 	"abab5.5s-chat": 0.005 * RMB, |  | ||||||
| 	// https://docs.mistral.ai/platform/pricing/ |  | ||||||
| 	"open-mistral-7b":       0.25 / 1000 * USD, |  | ||||||
| 	"open-mixtral-8x7b":     0.7 / 1000 * USD, |  | ||||||
| 	"mistral-small-latest":  2.0 / 1000 * USD, |  | ||||||
| 	"mistral-medium-latest": 2.7 / 1000 * USD, |  | ||||||
| 	"mistral-large-latest":  8.0 / 1000 * USD, |  | ||||||
| 	"mistral-embed":         0.1 / 1000 * USD, |  | ||||||
| 	// https://wow.groq.com/ |  | ||||||
| 	"llama2-70b-4096":    0.7 / 1000 * USD, |  | ||||||
| 	"llama2-7b-2048":     0.1 / 1000 * USD, |  | ||||||
| 	"mixtral-8x7b-32768": 0.27 / 1000 * USD, |  | ||||||
| 	"gemma-7b-it":        0.1 / 1000 * USD, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var CompletionRatio = map[string]float64{} |  | ||||||
|  |  | ||||||
| var DefaultModelRatio map[string]float64 |  | ||||||
| var DefaultCompletionRatio map[string]float64 |  | ||||||
|  |  | ||||||
| func init() { |  | ||||||
| 	DefaultModelRatio = make(map[string]float64) |  | ||||||
| 	for k, v := range ModelRatio { |  | ||||||
| 		DefaultModelRatio[k] = v |  | ||||||
| 	} |  | ||||||
| 	DefaultCompletionRatio = make(map[string]float64) |  | ||||||
| 	for k, v := range CompletionRatio { |  | ||||||
| 		DefaultCompletionRatio[k] = v |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { | func ModelRatio2JSONString() string { | ||||||
| 	jsonBytes, err := json.Marshal(ModelRatio) | 	jsonBytes, err := json.Marshal(ModelRatio) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("error marshalling model ratio: " + err.Error()) | 		SysError("error marshalling model ratio: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return string(jsonBytes) | 	return string(jsonBytes) | ||||||
| } | } | ||||||
| @@ -167,41 +122,14 @@ func GetModelRatio(name string) float64 { | |||||||
| 	} | 	} | ||||||
| 	ratio, ok := ModelRatio[name] | 	ratio, ok := ModelRatio[name] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		ratio, ok = DefaultModelRatio[name] | 		SysError("model ratio not found: " + name) | ||||||
| 	} |  | ||||||
| 	if !ok { |  | ||||||
| 		logger.SysError("model ratio not found: " + name) |  | ||||||
| 		return 30 | 		return 30 | ||||||
| 	} | 	} | ||||||
| 	return ratio | 	return ratio | ||||||
| } | } | ||||||
|  |  | ||||||
| func CompletionRatio2JSONString() string { |  | ||||||
| 	jsonBytes, err := json.Marshal(CompletionRatio) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.SysError("error marshalling completion ratio: " + err.Error()) |  | ||||||
| 	} |  | ||||||
| 	return string(jsonBytes) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func UpdateCompletionRatioByJSONString(jsonStr string) error { |  | ||||||
| 	CompletionRatio = make(map[string]float64) |  | ||||||
| 	return json.Unmarshal([]byte(jsonStr), &CompletionRatio) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetCompletionRatio(name string) float64 { | func GetCompletionRatio(name string) float64 { | ||||||
| 	if ratio, ok := CompletionRatio[name]; ok { |  | ||||||
| 		return ratio |  | ||||||
| 	} |  | ||||||
| 	if ratio, ok := DefaultCompletionRatio[name]; ok { |  | ||||||
| 		return ratio |  | ||||||
| 	} |  | ||||||
| 	if strings.HasPrefix(name, "gpt-3.5") { | 	if strings.HasPrefix(name, "gpt-3.5") { | ||||||
| 		if strings.HasSuffix(name, "0125") { |  | ||||||
| 			// https://openai.com/blog/new-embedding-models-and-api-updates |  | ||||||
| 			// Updated GPT-3.5 Turbo model and lower pricing |  | ||||||
| 			return 3 |  | ||||||
| 		} |  | ||||||
| 		if strings.HasSuffix(name, "1106") { | 		if strings.HasSuffix(name, "1106") { | ||||||
| 			return 2 | 			return 2 | ||||||
| 		} | 		} | ||||||
| @@ -214,7 +142,7 @@ func GetCompletionRatio(name string) float64 { | |||||||
| 				return 2 | 				return 2 | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		return 4.0 / 3.0 | 		return 1.333333 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "gpt-4") { | 	if strings.HasPrefix(name, "gpt-4") { | ||||||
| 		if strings.HasSuffix(name, "preview") { | 		if strings.HasSuffix(name, "preview") { | ||||||
| @@ -222,18 +150,11 @@ func GetCompletionRatio(name string) float64 { | |||||||
| 		} | 		} | ||||||
| 		return 2 | 		return 2 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "claude-3") { | 	if strings.HasPrefix(name, "claude-instant-1") { | ||||||
| 		return 5 | 		return 3.38 | ||||||
| 	} | 	} | ||||||
| 	if strings.HasPrefix(name, "claude-") { | 	if strings.HasPrefix(name, "claude-2") { | ||||||
| 		return 3 | 		return 2.965517 | ||||||
| 	} |  | ||||||
| 	if strings.HasPrefix(name, "mistral-") { |  | ||||||
| 		return 3 |  | ||||||
| 	} |  | ||||||
| 	switch name { |  | ||||||
| 	case "llama2-70b-4096": |  | ||||||
| 		return 0.8 / 0.7 |  | ||||||
| 	} | 	} | ||||||
| 	return 1 | 	return 1 | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										59
									
								
								common/quota.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								common/quota.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | |||||||
|  | package common | ||||||
|  |  | ||||||
|  | // type Quota struct { | ||||||
|  | // 	ModelName  string | ||||||
|  | // 	ModelRatio float64 | ||||||
|  | // 	GroupRatio float64 | ||||||
|  | // 	Ratio      float64 | ||||||
|  | // 	UserQuota  int | ||||||
|  | // } | ||||||
|  |  | ||||||
|  | // func CreateQuota(modelName string, userQuota int, group string) *Quota { | ||||||
|  | // 	modelRatio := GetModelRatio(modelName) | ||||||
|  | // 	groupRatio := GetGroupRatio(group) | ||||||
|  |  | ||||||
|  | // 	return &Quota{ | ||||||
|  | // 		ModelName:  modelName, | ||||||
|  | // 		ModelRatio: modelRatio, | ||||||
|  | // 		GroupRatio: groupRatio, | ||||||
|  | // 		Ratio:      modelRatio * groupRatio, | ||||||
|  | // 		UserQuota:  userQuota, | ||||||
|  | // 	} | ||||||
|  | // } | ||||||
|  |  | ||||||
|  | // func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||||
|  | // 	if ApproximateTokenEnabled { | ||||||
|  | // 		return int(float64(len(text)) * 0.38) | ||||||
|  | // 	} | ||||||
|  | // 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||||
|  | // } | ||||||
|  |  | ||||||
|  | // func (q *Quota) CountTokenMessages(messages []Message, model string) int { | ||||||
|  | // 	tokenEncoder := q.getTokenEncoder(model) | ||||||
|  | // 	// Reference: | ||||||
|  | // 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||||
|  | // 	// https://github.com/pkoukk/tiktoken-go/issues/6 | ||||||
|  | // 	// | ||||||
|  | // 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n | ||||||
|  | // 	var tokensPerMessage int | ||||||
|  | // 	var tokensPerName int | ||||||
|  | // 	if model == "gpt-3.5-turbo-0301" { | ||||||
|  | // 		tokensPerMessage = 4 | ||||||
|  | // 		tokensPerName = -1 // If there's a name, the role is omitted | ||||||
|  | // 	} else { | ||||||
|  | // 		tokensPerMessage = 3 | ||||||
|  | // 		tokensPerName = 1 | ||||||
|  | // 	} | ||||||
|  | // 	tokenNum := 0 | ||||||
|  | // 	for _, message := range messages { | ||||||
|  | // 		tokenNum += tokensPerMessage | ||||||
|  | // 		tokenNum += q.getTokenNum(tokenEncoder, message.StringContent()) | ||||||
|  | // 		tokenNum += q.getTokenNum(tokenEncoder, message.Role) | ||||||
|  | // 		if message.Name != nil { | ||||||
|  | // 			tokenNum += tokensPerName | ||||||
|  | // 			tokenNum += q.getTokenNum(tokenEncoder, *message.Name) | ||||||
|  | // 		} | ||||||
|  | // 	} | ||||||
|  | // 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> | ||||||
|  | // 	return tokenNum | ||||||
|  | // } | ||||||
| @@ -1,8 +0,0 @@ | |||||||
| package common |  | ||||||
|  |  | ||||||
| import "math/rand" |  | ||||||
|  |  | ||||||
| // RandRange returns a random number between min and max (max is not included) |  | ||||||
| func RandRange(min, max int) int { |  | ||||||
| 	return min + rand.Intn(max-min) |  | ||||||
| } |  | ||||||
| @@ -3,7 +3,6 @@ package common | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"github.com/go-redis/redis/v8" | 	"github.com/go-redis/redis/v8" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @@ -15,18 +14,18 @@ var RedisEnabled = true | |||||||
| func InitRedisClient() (err error) { | func InitRedisClient() (err error) { | ||||||
| 	if os.Getenv("REDIS_CONN_STRING") == "" { | 	if os.Getenv("REDIS_CONN_STRING") == "" { | ||||||
| 		RedisEnabled = false | 		RedisEnabled = false | ||||||
| 		logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") | 		SysLog("REDIS_CONN_STRING not set, Redis is not enabled") | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("SYNC_FREQUENCY") == "" { | 	if os.Getenv("SYNC_FREQUENCY") == "" { | ||||||
| 		RedisEnabled = false | 		RedisEnabled = false | ||||||
| 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | 		SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	logger.SysLog("Redis is enabled") | 	SysLog("Redis is enabled") | ||||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | 		FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	RDB = redis.NewClient(opt) | 	RDB = redis.NewClient(opt) | ||||||
|  |  | ||||||
| @@ -35,7 +34,7 @@ func InitRedisClient() (err error) { | |||||||
|  |  | ||||||
| 	_, err = RDB.Ping(ctx).Result() | 	_, err = RDB.Ping(ctx).Result() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog("Redis ping test failed: " + err.Error()) | 		FatalLog("Redis ping test failed: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -43,7 +42,7 @@ func InitRedisClient() (err error) { | |||||||
| func ParseRedisOption() *redis.Options { | func ParseRedisOption() *redis.Options { | ||||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | 		FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return opt | 	return opt | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										50
									
								
								common/request_builder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								common/request_builder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | |||||||
|  | package common | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type RequestBuilder interface { | ||||||
|  | 	Build(method, url string, body any, header http.Header) (*http.Request, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type HTTPRequestBuilder struct { | ||||||
|  | 	marshaller Marshaller | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewRequestBuilder() *HTTPRequestBuilder { | ||||||
|  | 	return &HTTPRequestBuilder{ | ||||||
|  | 		marshaller: &JSONMarshaller{}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *HTTPRequestBuilder) Build( | ||||||
|  | 	method string, | ||||||
|  | 	url string, | ||||||
|  | 	body any, | ||||||
|  | 	header http.Header, | ||||||
|  | ) (req *http.Request, err error) { | ||||||
|  | 	var bodyReader io.Reader | ||||||
|  | 	if body != nil { | ||||||
|  | 		if v, ok := body.(io.Reader); ok { | ||||||
|  | 			bodyReader = v | ||||||
|  | 		} else { | ||||||
|  | 			var reqBytes []byte | ||||||
|  | 			reqBytes, err = b.marshaller.Marshal(body) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			bodyReader = bytes.NewBuffer(reqBytes) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	req, err = http.NewRequest(method, url, bodyReader) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if header != nil { | ||||||
|  | 		req.Header = header | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
| @@ -1,34 +1,32 @@ | |||||||
| package openai | package common | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/pkoukk/tiktoken-go" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/image" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" |  | ||||||
| 	"math" | 	"math" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"one-api/common/image" | ||||||
|  | 	"one-api/types" | ||||||
|  | 
 | ||||||
|  | 	"github.com/pkoukk/tiktoken-go" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // tokenEncoderMap won't grow after initialization |  | ||||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||||
| var defaultTokenEncoder *tiktoken.Tiktoken | var defaultTokenEncoder *tiktoken.Tiktoken | ||||||
| 
 | 
 | ||||||
| func InitTokenEncoders() { | func InitTokenEncoders() { | ||||||
| 	logger.SysLog("initializing token encoders") | 	SysLog("initializing token encoders") | ||||||
| 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | 		FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||||
| 	} | 	} | ||||||
| 	defaultTokenEncoder = gpt35TokenEncoder | 	defaultTokenEncoder = gpt35TokenEncoder | ||||||
| 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | 		FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||||
| 	} | 	} | ||||||
| 	for model := range common.ModelRatio { | 	for model, _ := range ModelRatio { | ||||||
| 		if strings.HasPrefix(model, "gpt-3.5") { | 		if strings.HasPrefix(model, "gpt-3.5") { | ||||||
| 			tokenEncoderMap[model] = gpt35TokenEncoder | 			tokenEncoderMap[model] = gpt35TokenEncoder | ||||||
| 		} else if strings.HasPrefix(model, "gpt-4") { | 		} else if strings.HasPrefix(model, "gpt-4") { | ||||||
| @@ -37,7 +35,7 @@ func InitTokenEncoders() { | |||||||
| 			tokenEncoderMap[model] = nil | 			tokenEncoderMap[model] = nil | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	logger.SysLog("token encoders initialized") | 	SysLog("token encoders initialized") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||||
| @@ -48,7 +46,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { | |||||||
| 	if ok { | 	if ok { | ||||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | 			SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||||
| 			tokenEncoder = defaultTokenEncoder | 			tokenEncoder = defaultTokenEncoder | ||||||
| 		} | 		} | ||||||
| 		tokenEncoderMap[model] = tokenEncoder | 		tokenEncoderMap[model] = tokenEncoder | ||||||
| @@ -58,13 +56,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||||
| 	if config.ApproximateTokenEnabled { | 	if ApproximateTokenEnabled { | ||||||
| 		return int(float64(len(text)) * 0.38) | 		return int(float64(len(text)) * 0.38) | ||||||
| 	} | 	} | ||||||
| 	return len(tokenEncoder.Encode(text, nil, nil)) | 	return len(tokenEncoder.Encode(text, nil, nil)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func CountTokenMessages(messages []model.Message, model string) int { | func CountTokenMessages(messages []types.ChatCompletionMessage, model string) int { | ||||||
| 	tokenEncoder := getTokenEncoder(model) | 	tokenEncoder := getTokenEncoder(model) | ||||||
| 	// Reference: | 	// Reference: | ||||||
| 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||||
| @@ -102,7 +100,7 @@ func CountTokenMessages(messages []model.Message, model string) int { | |||||||
| 						} | 						} | ||||||
| 						imageTokens, err := countImageTokens(url, detail) | 						imageTokens, err := countImageTokens(url, detail) | ||||||
| 						if err != nil { | 						if err != nil { | ||||||
| 							logger.SysError("error counting image tokens: " + err.Error()) | 							SysError("error counting image tokens: " + err.Error()) | ||||||
| 						} else { | 						} else { | ||||||
| 							tokenNum += imageTokens | 							tokenNum += imageTokens | ||||||
| 						} | 						} | ||||||
| @@ -110,6 +108,7 @@ func CountTokenMessages(messages []model.Message, model string) int { | |||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | 		tokenNum += getTokenNum(tokenEncoder, message.StringContent()) | ||||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Role) | 		tokenNum += getTokenNum(tokenEncoder, message.Role) | ||||||
| 		if message.Name != nil { | 		if message.Name != nil { | ||||||
| 			tokenNum += tokensPerName | 			tokenNum += tokensPerName | ||||||
| @@ -191,13 +190,13 @@ func countImageTokens(url string, detail string) (_ int, err error) { | |||||||
| func CountTokenInput(input any, model string) int { | func CountTokenInput(input any, model string) int { | ||||||
| 	switch v := input.(type) { | 	switch v := input.(type) { | ||||||
| 	case string: | 	case string: | ||||||
| 		return CountTokenText(v, model) | 		return CountTokenInput(v, model) | ||||||
| 	case []string: | 	case []string: | ||||||
| 		text := "" | 		text := "" | ||||||
| 		for _, s := range v { | 		for _, s := range v { | ||||||
| 			text += s | 			text += s | ||||||
| 		} | 		} | ||||||
| 		return CountTokenText(text, model) | 		return CountTokenInput(text, model) | ||||||
| 	} | 	} | ||||||
| 	return 0 | 	return 0 | ||||||
| } | } | ||||||
| @@ -206,3 +205,34 @@ func CountTokenText(text string, model string) int { | |||||||
| 	tokenEncoder := getTokenEncoder(model) | 	tokenEncoder := getTokenEncoder(model) | ||||||
| 	return getTokenNum(tokenEncoder, text) | 	return getTokenNum(tokenEncoder, text) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func CountTokenImage(input interface{}) (int, error) { | ||||||
|  | 	switch v := input.(type) { | ||||||
|  | 	case types.ImageRequest: | ||||||
|  | 		// 处理 ImageRequest | ||||||
|  | 		return calculateToken(v.Model, v.Size, v.N, v.Quality) | ||||||
|  | 	case types.ImageEditRequest: | ||||||
|  | 		// 处理 ImageEditsRequest | ||||||
|  | 		return calculateToken(v.Model, v.Size, v.N, "") | ||||||
|  | 	default: | ||||||
|  | 		return 0, errors.New("unsupported type") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func calculateToken(model string, size string, n int, quality string) (int, error) { | ||||||
|  | 	imageCostRatio, hasValidSize := DalleSizeRatios[model][size] | ||||||
|  | 
 | ||||||
|  | 	if hasValidSize { | ||||||
|  | 		if quality == "hd" && model == "dall-e-3" { | ||||||
|  | 			if size == "1024x1024" { | ||||||
|  | 				imageCostRatio *= 2 | ||||||
|  | 			} else { | ||||||
|  | 				imageCostRatio *= 1.5 | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		return 0, errors.New("size not supported for this image model") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return int(imageCostRatio*1000) * n, nil | ||||||
|  | } | ||||||
							
								
								
									
										207
									
								
								common/utils.go
									
									
									
									
									
								
							
							
						
						
									
										207
									
								
								common/utils.go
									
									
									
									
									
								
							| @@ -2,13 +2,208 @@ package common | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"github.com/google/uuid" | ||||||
|  | 	"html/template" | ||||||
|  | 	"log" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"net" | ||||||
|  | 	"os" | ||||||
|  | 	"os/exec" | ||||||
|  | 	"runtime" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func LogQuota(quota int) string { | func OpenBrowser(url string) { | ||||||
| 	if config.DisplayInCurrencyEnabled { | 	var err error | ||||||
| 		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) |  | ||||||
| 	} else { | 	switch runtime.GOOS { | ||||||
| 		return fmt.Sprintf("%d 点额度", quota) | 	case "linux": | ||||||
|  | 		err = exec.Command("xdg-open", url).Start() | ||||||
|  | 	case "windows": | ||||||
|  | 		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() | ||||||
|  | 	case "darwin": | ||||||
|  | 		err = exec.Command("open", url).Start() | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Println(err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetIp() (ip string) { | ||||||
|  | 	ips, err := net.InterfaceAddrs() | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Println(err) | ||||||
|  | 		return ip | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, a := range ips { | ||||||
|  | 		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { | ||||||
|  | 			if ipNet.IP.To4() != nil { | ||||||
|  | 				ip = ipNet.IP.String() | ||||||
|  | 				if strings.HasPrefix(ip, "10") { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				if strings.HasPrefix(ip, "172") { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				if strings.HasPrefix(ip, "192.168") { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				ip = "" | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var sizeKB = 1024 | ||||||
|  | var sizeMB = sizeKB * 1024 | ||||||
|  | var sizeGB = sizeMB * 1024 | ||||||
|  |  | ||||||
|  | func Bytes2Size(num int64) string { | ||||||
|  | 	numStr := "" | ||||||
|  | 	unit := "B" | ||||||
|  | 	if num/int64(sizeGB) > 1 { | ||||||
|  | 		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) | ||||||
|  | 		unit = "GB" | ||||||
|  | 	} else if num/int64(sizeMB) > 1 { | ||||||
|  | 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) | ||||||
|  | 		unit = "MB" | ||||||
|  | 	} else if num/int64(sizeKB) > 1 { | ||||||
|  | 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) | ||||||
|  | 		unit = "KB" | ||||||
|  | 	} else { | ||||||
|  | 		numStr = fmt.Sprintf("%d", num) | ||||||
|  | 	} | ||||||
|  | 	return numStr + " " + unit | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Seconds2Time(num int) (time string) { | ||||||
|  | 	if num/31104000 > 0 { | ||||||
|  | 		time += strconv.Itoa(num/31104000) + " 年 " | ||||||
|  | 		num %= 31104000 | ||||||
|  | 	} | ||||||
|  | 	if num/2592000 > 0 { | ||||||
|  | 		time += strconv.Itoa(num/2592000) + " 个月 " | ||||||
|  | 		num %= 2592000 | ||||||
|  | 	} | ||||||
|  | 	if num/86400 > 0 { | ||||||
|  | 		time += strconv.Itoa(num/86400) + " 天 " | ||||||
|  | 		num %= 86400 | ||||||
|  | 	} | ||||||
|  | 	if num/3600 > 0 { | ||||||
|  | 		time += strconv.Itoa(num/3600) + " 小时 " | ||||||
|  | 		num %= 3600 | ||||||
|  | 	} | ||||||
|  | 	if num/60 > 0 { | ||||||
|  | 		time += strconv.Itoa(num/60) + " 分钟 " | ||||||
|  | 		num %= 60 | ||||||
|  | 	} | ||||||
|  | 	time += strconv.Itoa(num) + " 秒" | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Interface2String(inter interface{}) string { | ||||||
|  | 	switch inter.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		return inter.(string) | ||||||
|  | 	case int: | ||||||
|  | 		return fmt.Sprintf("%d", inter.(int)) | ||||||
|  | 	case float64: | ||||||
|  | 		return fmt.Sprintf("%f", inter.(float64)) | ||||||
|  | 	} | ||||||
|  | 	return "Not Implemented" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UnescapeHTML(x string) interface{} { | ||||||
|  | 	return template.HTML(x) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func IntMax(a int, b int) int { | ||||||
|  | 	if a >= b { | ||||||
|  | 		return a | ||||||
|  | 	} else { | ||||||
|  | 		return b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetUUID() string { | ||||||
|  | 	code := uuid.New().String() | ||||||
|  | 	code = strings.Replace(code, "-", "", -1) | ||||||
|  | 	return code | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GenerateKey() string { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | 	key := make([]byte, 48) | ||||||
|  | 	for i := 0; i < 16; i++ { | ||||||
|  | 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||||
|  | 	} | ||||||
|  | 	uuid_ := GetUUID() | ||||||
|  | 	for i := 0; i < 32; i++ { | ||||||
|  | 		c := uuid_[i] | ||||||
|  | 		if i%2 == 0 && c >= 'a' && c <= 'z' { | ||||||
|  | 			c = c - 'a' + 'A' | ||||||
|  | 		} | ||||||
|  | 		key[i+16] = c | ||||||
|  | 	} | ||||||
|  | 	return string(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetRandomString(length int) string { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | 	key := make([]byte, length) | ||||||
|  | 	for i := 0; i < length; i++ { | ||||||
|  | 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||||
|  | 	} | ||||||
|  | 	return string(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetTimestamp() int64 { | ||||||
|  | 	return time.Now().Unix() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetTimeString() string { | ||||||
|  | 	now := time.Now() | ||||||
|  | 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Max(a int, b int) int { | ||||||
|  | 	if a >= b { | ||||||
|  | 		return a | ||||||
|  | 	} else { | ||||||
|  | 		return b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetOrDefault(env string, defaultValue int) int { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	num, err := strconv.Atoi(os.Getenv(env)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func MessageWithRequestId(message string, id string) string { | ||||||
|  | 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func String2Int(str string) int { | ||||||
|  | 	num, err := strconv.Atoi(str) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,10 +1,11 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetSubscription(c *gin.Context) { | func GetSubscription(c *gin.Context) { | ||||||
| @@ -13,7 +14,7 @@ func GetSubscription(c *gin.Context) { | |||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
| 	var expiredTime int64 | 	var expiredTime int64 | ||||||
| 	if config.DisplayTokenStatEnabled { | 	if common.DisplayTokenStatEnabled { | ||||||
| 		tokenId := c.GetInt("token_id") | 		tokenId := c.GetInt("token_id") | ||||||
| 		token, err = model.GetTokenById(tokenId) | 		token, err = model.GetTokenById(tokenId) | ||||||
| 		expiredTime = token.ExpiredTime | 		expiredTime = token.ExpiredTime | ||||||
| @@ -23,26 +24,34 @@ func GetSubscription(c *gin.Context) { | |||||||
| 		userId := c.GetInt("id") | 		userId := c.GetInt("id") | ||||||
| 		remainQuota, err = model.GetUserQuota(userId) | 		remainQuota, err = model.GetUserQuota(userId) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			usedQuota, err = model.GetUserUsedQuota(userId) | 			openAIError := types.OpenAIError{ | ||||||
|  | 				Message: err.Error(), | ||||||
|  | 				Type:    "upstream_error", | ||||||
|  | 			} | ||||||
|  | 			c.JSON(200, gin.H{ | ||||||
|  | 				"error": openAIError, | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
| 		} | 		} | ||||||
|  | 		usedQuota, err = model.GetUserUsedQuota(userId) | ||||||
| 	} | 	} | ||||||
| 	if expiredTime <= 0 { | 	if expiredTime <= 0 { | ||||||
| 		expiredTime = 0 | 		expiredTime = 0 | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		Error := relaymodel.Error{ | 		openAIError := types.OpenAIError{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "upstream_error", | 			Type:    "upstream_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": Error, | 			"error": openAIError, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	quota := remainQuota + usedQuota | 	quota := remainQuota + usedQuota | ||||||
| 	amount := float64(quota) | 	amount := float64(quota) | ||||||
| 	if config.DisplayInCurrencyEnabled { | 	if common.DisplayInCurrencyEnabled { | ||||||
| 		amount /= config.QuotaPerUnit | 		amount /= common.QuotaPerUnit | ||||||
| 	} | 	} | ||||||
| 	if token != nil && token.UnlimitedQuota { | 	if token != nil && token.UnlimitedQuota { | ||||||
| 		amount = 100000000 | 		amount = 100000000 | ||||||
| @@ -56,14 +65,13 @@ func GetSubscription(c *gin.Context) { | |||||||
| 		AccessUntil:        expiredTime, | 		AccessUntil:        expiredTime, | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, subscription) | 	c.JSON(200, subscription) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUsage(c *gin.Context) { | func GetUsage(c *gin.Context) { | ||||||
| 	var quota int | 	var quota int | ||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
| 	if config.DisplayTokenStatEnabled { | 	if common.DisplayTokenStatEnabled { | ||||||
| 		tokenId := c.GetInt("token_id") | 		tokenId := c.GetInt("token_id") | ||||||
| 		token, err = model.GetTokenById(tokenId) | 		token, err = model.GetTokenById(tokenId) | ||||||
| 		quota = token.UsedQuota | 		quota = token.UsedQuota | ||||||
| @@ -72,23 +80,22 @@ func GetUsage(c *gin.Context) { | |||||||
| 		quota, err = model.GetUserUsedQuota(userId) | 		quota, err = model.GetUserUsedQuota(userId) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		Error := relaymodel.Error{ | 		openAIError := types.OpenAIError{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "one_api_error", | 			Type:    "one_api_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": Error, | 			"error": openAIError, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	amount := float64(quota) | 	amount := float64(quota) | ||||||
| 	if config.DisplayInCurrencyEnabled { | 	if common.DisplayInCurrencyEnabled { | ||||||
| 		amount /= config.QuotaPerUnit | 		amount /= common.QuotaPerUnit | ||||||
| 	} | 	} | ||||||
| 	usage := OpenAIUsageResponse{ | 	usage := OpenAIUsageResponse{ | ||||||
| 		Object:     "list", | 		Object:     "list", | ||||||
| 		TotalUsage: amount * 100, | 		TotalUsage: amount * 100, | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, usage) | 	c.JSON(200, usage) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,16 +1,13 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/providers" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -49,216 +46,29 @@ type OpenAIUsageResponse struct { | |||||||
| 	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar | 	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar | ||||||
| } | } | ||||||
|  |  | ||||||
| type OpenAISBUsageResponse struct { |  | ||||||
| 	Msg  string `json:"msg"` |  | ||||||
| 	Data *struct { |  | ||||||
| 		Credit string `json:"credit"` |  | ||||||
| 	} `json:"data"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AIProxyUserOverviewResponse struct { |  | ||||||
| 	Success   bool   `json:"success"` |  | ||||||
| 	Message   string `json:"message"` |  | ||||||
| 	ErrorCode int    `json:"error_code"` |  | ||||||
| 	Data      struct { |  | ||||||
| 		TotalPoints float64 `json:"totalPoints"` |  | ||||||
| 	} `json:"data"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type API2GPTUsageResponse struct { |  | ||||||
| 	Object         string  `json:"object"` |  | ||||||
| 	TotalGranted   float64 `json:"total_granted"` |  | ||||||
| 	TotalUsed      float64 `json:"total_used"` |  | ||||||
| 	TotalRemaining float64 `json:"total_remaining"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type APGC2DGPTUsageResponse struct { |  | ||||||
| 	//Grants         interface{} `json:"grants"` |  | ||||||
| 	Object         string  `json:"object"` |  | ||||||
| 	TotalAvailable float64 `json:"total_available"` |  | ||||||
| 	TotalGranted   float64 `json:"total_granted"` |  | ||||||
| 	TotalUsed      float64 `json:"total_used"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // GetAuthHeader get auth header |  | ||||||
| func GetAuthHeader(token string) http.Header { |  | ||||||
| 	h := http.Header{} |  | ||||||
| 	h.Add("Authorization", fmt.Sprintf("Bearer %s", token)) |  | ||||||
| 	return h |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { |  | ||||||
| 	req, err := http.NewRequest(method, url, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	for k := range headers { |  | ||||||
| 		req.Header.Add(k, headers.Get(k)) |  | ||||||
| 	} |  | ||||||
| 	res, err := util.HTTPClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	if res.StatusCode != http.StatusOK { |  | ||||||
| 		return nil, fmt.Errorf("status code: %d", res.StatusCode) |  | ||||||
| 	} |  | ||||||
| 	body, err := io.ReadAll(res.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	err = res.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	return body, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { |  | ||||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) |  | ||||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) |  | ||||||
|  |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	response := OpenAICreditGrants{} |  | ||||||
| 	err = json.Unmarshal(body, &response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	channel.UpdateBalance(response.TotalAvailable) |  | ||||||
| 	return response.TotalAvailable, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { |  | ||||||
| 	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) |  | ||||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	response := OpenAISBUsageResponse{} |  | ||||||
| 	err = json.Unmarshal(body, &response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	if response.Data == nil { |  | ||||||
| 		return 0, errors.New(response.Msg) |  | ||||||
| 	} |  | ||||||
| 	balance, err := strconv.ParseFloat(response.Data.Credit, 64) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	channel.UpdateBalance(balance) |  | ||||||
| 	return balance, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { |  | ||||||
| 	url := "https://aiproxy.io/api/report/getUserOverview" |  | ||||||
| 	headers := http.Header{} |  | ||||||
| 	headers.Add("Api-Key", channel.Key) |  | ||||||
| 	body, err := GetResponseBody("GET", url, channel, headers) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	response := AIProxyUserOverviewResponse{} |  | ||||||
| 	err = json.Unmarshal(body, &response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	if !response.Success { |  | ||||||
| 		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) |  | ||||||
| 	} |  | ||||||
| 	channel.UpdateBalance(response.Data.TotalPoints) |  | ||||||
| 	return response.Data.TotalPoints, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) { |  | ||||||
| 	url := "https://api.api2gpt.com/dashboard/billing/credit_grants" |  | ||||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) |  | ||||||
|  |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	response := API2GPTUsageResponse{} |  | ||||||
| 	err = json.Unmarshal(body, &response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	channel.UpdateBalance(response.TotalRemaining) |  | ||||||
| 	return response.TotalRemaining, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { |  | ||||||
| 	url := "https://api.aigc2d.com/dashboard/billing/credit_grants" |  | ||||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	response := APGC2DGPTUsageResponse{} |  | ||||||
| 	err = json.Unmarshal(body, &response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	channel.UpdateBalance(response.TotalAvailable) |  | ||||||
| 	return response.TotalAvailable, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | 	req, err := http.NewRequest("POST", "/balance", nil) | ||||||
| 	if channel.GetBaseURL() == "" { | 	if err != nil { | ||||||
| 		channel.BaseURL = &baseURL | 		return 0, err | ||||||
| 	} | 	} | ||||||
| 	switch channel.Type { | 	w := httptest.NewRecorder() | ||||||
| 	case common.ChannelTypeOpenAI: | 	c, _ := gin.CreateTestContext(w) | ||||||
| 		if channel.GetBaseURL() != "" { | 	c.Request = req | ||||||
| 			baseURL = channel.GetBaseURL() |  | ||||||
| 		} |  | ||||||
| 	case common.ChannelTypeAzure: |  | ||||||
| 		return 0, errors.New("尚未实现") |  | ||||||
| 	case common.ChannelTypeCustom: |  | ||||||
| 		baseURL = channel.GetBaseURL() |  | ||||||
| 	case common.ChannelTypeCloseAI: |  | ||||||
| 		return updateChannelCloseAIBalance(channel) |  | ||||||
| 	case common.ChannelTypeOpenAISB: |  | ||||||
| 		return updateChannelOpenAISBBalance(channel) |  | ||||||
| 	case common.ChannelTypeAIProxy: |  | ||||||
| 		return updateChannelAIProxyBalance(channel) |  | ||||||
| 	case common.ChannelTypeAPI2GPT: |  | ||||||
| 		return updateChannelAPI2GPTBalance(channel) |  | ||||||
| 	case common.ChannelTypeAIGC2D: |  | ||||||
| 		return updateChannelAIGC2DBalance(channel) |  | ||||||
| 	default: |  | ||||||
| 		return 0, errors.New("尚未实现") |  | ||||||
| 	} |  | ||||||
| 	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) |  | ||||||
|  |  | ||||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | 	req.Header.Set("Content-Type", "application/json") | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err | 	provider := providers.GetProvider(channel, c) | ||||||
|  | 	if provider == nil { | ||||||
|  | 		return 0, errors.New("provider not found") | ||||||
| 	} | 	} | ||||||
| 	subscription := OpenAISubscriptionResponse{} |  | ||||||
| 	err = json.Unmarshal(body, &subscription) | 	balanceProvider, ok := provider.(providersBase.BalanceInterface) | ||||||
| 	if err != nil { | 	if !ok { | ||||||
| 		return 0, err | 		return 0, errors.New("provider not implemented") | ||||||
| 	} | 	} | ||||||
| 	now := time.Now() |  | ||||||
| 	startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) | 	return balanceProvider.Balance(channel) | ||||||
| 	endDate := now.Format("2006-01-02") |  | ||||||
| 	if !subscription.HasPaymentMethod { |  | ||||||
| 		startDate = now.AddDate(0, 0, -100).Format("2006-01-02") |  | ||||||
| 	} |  | ||||||
| 	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) |  | ||||||
| 	body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	usage := OpenAIUsageResponse{} |  | ||||||
| 	err = json.Unmarshal(body, &usage) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	balance := subscription.HardLimitUSD - usage.TotalUsage/100 |  | ||||||
| 	channel.UpdateBalance(balance) |  | ||||||
| 	return balance, nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelBalance(c *gin.Context) { | func UpdateChannelBalance(c *gin.Context) { | ||||||
| @@ -291,7 +101,6 @@ func UpdateChannelBalance(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"balance": balance, | 		"balance": balance, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateAllChannelsBalance() error { | func updateAllChannelsBalance() error { | ||||||
| @@ -316,7 +125,7 @@ func updateAllChannelsBalance() error { | |||||||
| 				disableChannel(channel.Id, channel.Name, "余额不足") | 				disableChannel(channel.Id, channel.Name, "余额不足") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		time.Sleep(config.RequestInterval) | 		time.Sleep(common.RequestInterval) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -335,14 +144,13 @@ func UpdateAllChannelsBalance(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func AutomaticallyUpdateChannels(frequency int) { | func AutomaticallyUpdateChannels(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||||
| 		logger.SysLog("updating all channels") | 		common.SysLog("updating all channels") | ||||||
| 		_ = updateAllChannelsBalance() | 		_ = updateAllChannelsBalance() | ||||||
| 		logger.SysLog("channels update done") | 		common.SysLog("channels update done") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,111 +1,99 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"github.com/songquanpeng/one-api/middleware" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/helper" |  | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/providers" | ||||||
|  | 	providers_base "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) { | ||||||
| 	testRequest := &relaymodel.GeneralOpenAIRequest{ | 	// 创建一个 http.Request | ||||||
| 		MaxTokens: 1, | 	req, err := http.NewRequest("POST", "/v1/chat/completions", nil) | ||||||
| 		Stream:    false, | 	if err != nil { | ||||||
| 		Model:     "gpt-3.5-turbo", | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	testMessage := relaymodel.Message{ | 	req.Header.Set("Content-Type", "application/json") | ||||||
| 		Role:    "user", |  | ||||||
| 		Content: "hi", |  | ||||||
| 	} |  | ||||||
| 	testRequest.Messages = append(testRequest.Messages, testMessage) |  | ||||||
| 	return testRequest |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { |  | ||||||
| 	w := httptest.NewRecorder() | 	w := httptest.NewRecorder() | ||||||
| 	c, _ := gin.CreateTestContext(w) | 	c, _ := gin.CreateTestContext(w) | ||||||
| 	c.Request = &http.Request{ | 	c.Request = req | ||||||
| 		Method: "POST", |  | ||||||
| 		URL:    &url.URL{Path: "/v1/chat/completions"}, | 	// 创建映射 | ||||||
| 		Body:   nil, | 	channelTypeToModel := map[int]string{ | ||||||
| 		Header: make(http.Header), | 		common.ChannelTypePaLM:      "PaLM-2", | ||||||
|  | 		common.ChannelTypeAnthropic: "claude-2", | ||||||
|  | 		common.ChannelTypeBaidu:     "ERNIE-Bot", | ||||||
|  | 		common.ChannelTypeZhipu:     "chatglm_lite", | ||||||
|  | 		common.ChannelTypeAli:       "qwen-turbo", | ||||||
|  | 		common.ChannelType360:       "360GPT_S2_V9", | ||||||
|  | 		common.ChannelTypeXunfei:    "SparkDesk", | ||||||
|  | 		common.ChannelTypeTencent:   "hunyuan", | ||||||
|  | 		common.ChannelTypeAzure:     "gpt-3.5-turbo", | ||||||
| 	} | 	} | ||||||
| 	c.Request.Header.Set("Authorization", "Bearer "+channel.Key) |  | ||||||
| 	c.Request.Header.Set("Content-Type", "application/json") | 	// 从映射中获取模型名称 | ||||||
| 	c.Set("channel", channel.Type) | 	model, ok := channelTypeToModel[channel.Type] | ||||||
| 	c.Set("base_url", channel.GetBaseURL()) | 	if !ok { | ||||||
| 	middleware.SetupContextForSelectedChannel(c, channel, "") | 		model = "gpt-3.5-turbo" // 默认值 | ||||||
| 	meta := util.GetRelayMeta(c) |  | ||||||
| 	apiType := constant.ChannelType2APIType(channel.Type) |  | ||||||
| 	adaptor := helper.GetAdaptor(apiType) |  | ||||||
| 	if adaptor == nil { |  | ||||||
| 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil |  | ||||||
| 	} | 	} | ||||||
| 	adaptor.Init(meta) | 	request.Model = model | ||||||
| 	modelName := adaptor.GetModelList()[0] |  | ||||||
| 	if !strings.Contains(channel.Models, modelName) { | 	provider := providers.GetProvider(channel, c) | ||||||
| 		modelNames := strings.Split(channel.Models, ",") | 	if provider == nil { | ||||||
| 		if len(modelNames) > 0 { | 		return errors.New("channel not implemented"), nil | ||||||
| 			modelName = modelNames[0] |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	request := buildTestRequest() | 	chatProvider, ok := provider.(providers_base.ChatInterface) | ||||||
| 	request.Model = modelName | 	if !ok { | ||||||
| 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | 		return errors.New("channel not implemented"), nil | ||||||
| 	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) | 	} | ||||||
|  |  | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err, nil | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	jsonData, err := json.Marshal(convertedRequest) | 	if modelMap != nil && modelMap[request.Model] != "" { | ||||||
| 	if err != nil { | 		request.Model = modelMap[request.Model] | ||||||
| 		return err, nil |  | ||||||
| 	} | 	} | ||||||
| 	requestBody := bytes.NewBuffer(jsonData) |  | ||||||
| 	c.Request.Body = io.NopCloser(requestBody) | 	promptTokens := common.CountTokenMessages(request.Messages, request.Model) | ||||||
| 	resp, err := adaptor.DoRequest(c, meta, requestBody) | 	Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens) | ||||||
| 	if err != nil { | 	if openAIErrorWithStatusCode != nil { | ||||||
| 		return err, nil | 		return nil, &openAIErrorWithStatusCode.OpenAIError | ||||||
| 	} | 	} | ||||||
| 	if resp.StatusCode != http.StatusOK { |  | ||||||
| 		err := util.RelayErrorHandler(resp) | 	if Usage.CompletionTokens == 0 { | ||||||
| 		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error | 		return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil | ||||||
| 	} | 	} | ||||||
| 	usage, respErr := adaptor.DoResponse(c, resp, meta) |  | ||||||
| 	if respErr != nil { |  | ||||||
| 		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error |  | ||||||
| 	} |  | ||||||
| 	if usage == nil { |  | ||||||
| 		return errors.New("usage is nil"), nil |  | ||||||
| 	} |  | ||||||
| 	result := w.Result() |  | ||||||
| 	// print result.Body |  | ||||||
| 	respBody, err := io.ReadAll(result.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err, nil |  | ||||||
| 	} |  | ||||||
| 	logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) |  | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func buildTestRequest() *types.ChatCompletionRequest { | ||||||
|  | 	testRequest := &types.ChatCompletionRequest{ | ||||||
|  | 		Messages: []types.ChatCompletionMessage{ | ||||||
|  | 			{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: "You just need to output 'hi' next.", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		Model:     "", | ||||||
|  | 		MaxTokens: 1, | ||||||
|  | 		Stream:    false, | ||||||
|  | 	} | ||||||
|  | 	return testRequest | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestChannel(c *gin.Context) { | func TestChannel(c *gin.Context) { | ||||||
| 	id, err := strconv.Atoi(c.Param("id")) | 	id, err := strconv.Atoi(c.Param("id")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -123,8 +111,9 @@ func TestChannel(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | 	testRequest := buildTestRequest() | ||||||
| 	tik := time.Now() | 	tik := time.Now() | ||||||
| 	err, _ = testChannel(channel) | 	err, _ = testChannel(channel, *testRequest) | ||||||
| 	tok := time.Now() | 	tok := time.Now() | ||||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | 	milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 	go channel.UpdateResponseTime(milliseconds) | 	go channel.UpdateResponseTime(milliseconds) | ||||||
| @@ -142,19 +131,18 @@ func TestChannel(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"time":    consumedTime, | 		"time":    consumedTime, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| var testAllChannelsLock sync.Mutex | var testAllChannelsLock sync.Mutex | ||||||
| var testAllChannelsRunning bool = false | var testAllChannelsRunning bool = false | ||||||
|  |  | ||||||
| func notifyRootUser(subject string, content string) { | func notifyRootUser(subject string, content string) { | ||||||
| 	if config.RootUserEmail == "" { | 	if common.RootUserEmail == "" { | ||||||
| 		config.RootUserEmail = model.GetRootUserEmail() | 		common.RootUserEmail = model.GetRootUserEmail() | ||||||
| 	} | 	} | ||||||
| 	err := common.SendEmail(subject, config.RootUserEmail, content) | 	err := common.SendEmail(subject, common.RootUserEmail, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | 		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -175,8 +163,8 @@ func enableChannel(channelId int, channelName string) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func testAllChannels(notify bool) error { | func testAllChannels(notify bool) error { | ||||||
| 	if config.RootUserEmail == "" { | 	if common.RootUserEmail == "" { | ||||||
| 		config.RootUserEmail = model.GetRootUserEmail() | 		common.RootUserEmail = model.GetRootUserEmail() | ||||||
| 	} | 	} | ||||||
| 	testAllChannelsLock.Lock() | 	testAllChannelsLock.Lock() | ||||||
| 	if testAllChannelsRunning { | 	if testAllChannelsRunning { | ||||||
| @@ -189,7 +177,8 @@ func testAllChannels(notify bool) error { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	var disableThreshold = int64(config.ChannelDisableThreshold * 1000) | 	testRequest := buildTestRequest() | ||||||
|  | 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000) | ||||||
| 	if disableThreshold == 0 { | 	if disableThreshold == 0 { | ||||||
| 		disableThreshold = 10000000 // a impossible value | 		disableThreshold = 10000000 // a impossible value | ||||||
| 	} | 	} | ||||||
| @@ -197,29 +186,29 @@ func testAllChannels(notify bool) error { | |||||||
| 		for _, channel := range channels { | 		for _, channel := range channels { | ||||||
| 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled | 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled | ||||||
| 			tik := time.Now() | 			tik := time.Now() | ||||||
| 			err, openaiErr := testChannel(channel) | 			err, openaiErr := testChannel(channel, *testRequest) | ||||||
| 			tok := time.Now() | 			tok := time.Now() | ||||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | 			milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | 			if milliseconds > disableThreshold { | ||||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | 				err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | 			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | 			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { | ||||||
| 				enableChannel(channel.Id, channel.Name) | 				enableChannel(channel.Id, channel.Name) | ||||||
| 			} | 			} | ||||||
| 			channel.UpdateResponseTime(milliseconds) | 			channel.UpdateResponseTime(milliseconds) | ||||||
| 			time.Sleep(config.RequestInterval) | 			time.Sleep(common.RequestInterval) | ||||||
| 		} | 		} | ||||||
| 		testAllChannelsLock.Lock() | 		testAllChannelsLock.Lock() | ||||||
| 		testAllChannelsRunning = false | 		testAllChannelsRunning = false | ||||||
| 		testAllChannelsLock.Unlock() | 		testAllChannelsLock.Unlock() | ||||||
| 		if notify { | 		if notify { | ||||||
| 			err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | 			err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | 				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -239,14 +228,13 @@ func TestAllChannels(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func AutomaticallyTestChannels(frequency int) { | func AutomaticallyTestChannels(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||||
| 		logger.SysLog("testing all channels") | 		common.SysLog("testing all channels") | ||||||
| 		_ = testAllChannels(false) | 		_ = testAllChannels(false) | ||||||
| 		logger.SysLog("channel test finished") | 		common.SysLog("channel test finished") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,13 +1,13 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetAllChannels(c *gin.Context) { | func GetAllChannels(c *gin.Context) { | ||||||
| @@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) | 	channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -28,7 +28,6 @@ func GetAllChannels(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    channels, | 		"data":    channels, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchChannels(c *gin.Context) { | func SearchChannels(c *gin.Context) { | ||||||
| @@ -46,7 +45,6 @@ func SearchChannels(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    channels, | 		"data":    channels, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetChannel(c *gin.Context) { | func GetChannel(c *gin.Context) { | ||||||
| @@ -71,7 +69,6 @@ func GetChannel(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    channel, | 		"data":    channel, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func AddChannel(c *gin.Context) { | func AddChannel(c *gin.Context) { | ||||||
| @@ -84,7 +81,7 @@ func AddChannel(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	channel.CreatedTime = helper.GetTimestamp() | 	channel.CreatedTime = common.GetTimestamp() | ||||||
| 	keys := strings.Split(channel.Key, "\n") | 	keys := strings.Split(channel.Key, "\n") | ||||||
| 	channels := make([]model.Channel, 0, len(keys)) | 	channels := make([]model.Channel, 0, len(keys)) | ||||||
| 	for _, key := range keys { | 	for _, key := range keys { | ||||||
| @@ -107,7 +104,6 @@ func AddChannel(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteChannel(c *gin.Context) { | func DeleteChannel(c *gin.Context) { | ||||||
| @@ -125,7 +121,6 @@ func DeleteChannel(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteDisabledChannel(c *gin.Context) { | func DeleteDisabledChannel(c *gin.Context) { | ||||||
| @@ -142,7 +137,6 @@ func DeleteDisabledChannel(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    rows, | 		"data":    rows, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannel(c *gin.Context) { | func UpdateChannel(c *gin.Context) { | ||||||
| @@ -168,5 +162,4 @@ func UpdateChannel(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    channel, | 		"data":    channel, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -5,16 +5,14 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type GitHubOAuthResponse struct { | type GitHubOAuthResponse struct { | ||||||
| @@ -33,7 +31,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, errors.New("无效的参数") | 		return nil, errors.New("无效的参数") | ||||||
| 	} | 	} | ||||||
| 	values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} | 	values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} | ||||||
| 	jsonData, err := json.Marshal(values) | 	jsonData, err := json.Marshal(values) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -49,7 +47,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
| 	} | 	} | ||||||
| 	res, err := client.Do(req) | 	res, err := client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysLog(err.Error()) | 		common.SysLog(err.Error()) | ||||||
| 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | ||||||
| 	} | 	} | ||||||
| 	defer res.Body.Close() | 	defer res.Body.Close() | ||||||
| @@ -65,7 +63,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
| 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | ||||||
| 	res2, err := client.Do(req) | 	res2, err := client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysLog(err.Error()) | 		common.SysLog(err.Error()) | ||||||
| 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | ||||||
| 	} | 	} | ||||||
| 	defer res2.Body.Close() | 	defer res2.Body.Close() | ||||||
| @@ -96,7 +94,7 @@ func GitHubOAuth(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !config.GitHubOAuthEnabled { | 	if !common.GitHubOAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "管理员未开启通过 GitHub 登录以及注册", | 			"message": "管理员未开启通过 GitHub 登录以及注册", | ||||||
| @@ -125,7 +123,7 @@ func GitHubOAuth(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		if config.RegisterEnabled { | 		if common.RegisterEnabled { | ||||||
| 			user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) | 			user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||||
| 			if githubUser.Name != "" { | 			if githubUser.Name != "" { | ||||||
| 				user.DisplayName = githubUser.Name | 				user.DisplayName = githubUser.Name | ||||||
| @@ -163,7 +161,7 @@ func GitHubOAuth(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GitHubBind(c *gin.Context) { | func GitHubBind(c *gin.Context) { | ||||||
| 	if !config.GitHubOAuthEnabled { | 	if !common.GitHubOAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "管理员未开启通过 GitHub 登录以及注册", | 			"message": "管理员未开启通过 GitHub 登录以及注册", | ||||||
| @@ -214,12 +212,11 @@ func GitHubBind(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "bind", | 		"message": "bind", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GenerateOAuthCode(c *gin.Context) { | func GenerateOAuthCode(c *gin.Context) { | ||||||
| 	session := sessions.Default(c) | 	session := sessions.Default(c) | ||||||
| 	state := helper.GetRandomString(12) | 	state := common.GetRandomString(12) | ||||||
| 	session.Set("oauth_state", state) | 	session.Set("oauth_state", state) | ||||||
| 	err := session.Save() | 	err := session.Save() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -1,9 +1,10 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetGroups(c *gin.Context) { | func GetGroups(c *gin.Context) { | ||||||
|   | |||||||
| @@ -1,11 +1,12 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetAllLogs(c *gin.Context) { | func GetAllLogs(c *gin.Context) { | ||||||
| @@ -20,7 +21,7 @@ func GetAllLogs(c *gin.Context) { | |||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	channel, _ := strconv.Atoi(c.Query("channel")) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
| 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel) | 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -33,7 +34,6 @@ func GetAllLogs(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserLogs(c *gin.Context) { | func GetUserLogs(c *gin.Context) { | ||||||
| @@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) { | |||||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) | 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -60,7 +60,6 @@ func GetUserLogs(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchAllLogs(c *gin.Context) { | func SearchAllLogs(c *gin.Context) { | ||||||
| @@ -78,7 +77,6 @@ func SearchAllLogs(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUserLogs(c *gin.Context) { | func SearchUserLogs(c *gin.Context) { | ||||||
| @@ -97,7 +95,6 @@ func SearchUserLogs(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetLogsStat(c *gin.Context) { | func GetLogsStat(c *gin.Context) { | ||||||
| @@ -118,7 +115,6 @@ func GetLogsStat(c *gin.Context) { | |||||||
| 			//"token": tokenNum, | 			//"token": tokenNum, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetLogsSelfStat(c *gin.Context) { | func GetLogsSelfStat(c *gin.Context) { | ||||||
| @@ -139,7 +135,6 @@ func GetLogsSelfStat(c *gin.Context) { | |||||||
| 			//"token": tokenNum, | 			//"token": tokenNum, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteHistoryLogs(c *gin.Context) { | func DeleteHistoryLogs(c *gin.Context) { | ||||||
| @@ -164,5 +159,4 @@ func DeleteHistoryLogs(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    count, | 		"data":    count, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,10 +3,9 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -19,57 +18,53 @@ func GetStatus(c *gin.Context) { | |||||||
| 		"data": gin.H{ | 		"data": gin.H{ | ||||||
| 			"version":             common.Version, | 			"version":             common.Version, | ||||||
| 			"start_time":          common.StartTime, | 			"start_time":          common.StartTime, | ||||||
| 			"email_verification":  config.EmailVerificationEnabled, | 			"email_verification":  common.EmailVerificationEnabled, | ||||||
| 			"github_oauth":        config.GitHubOAuthEnabled, | 			"github_oauth":        common.GitHubOAuthEnabled, | ||||||
| 			"github_client_id":    config.GitHubClientId, | 			"github_client_id":    common.GitHubClientId, | ||||||
| 			"system_name":         config.SystemName, | 			"system_name":         common.SystemName, | ||||||
| 			"logo":                config.Logo, | 			"logo":                common.Logo, | ||||||
| 			"footer_html":         config.Footer, | 			"footer_html":         common.Footer, | ||||||
| 			"wechat_qrcode":       config.WeChatAccountQRCodeImageURL, | 			"wechat_qrcode":       common.WeChatAccountQRCodeImageURL, | ||||||
| 			"wechat_login":        config.WeChatAuthEnabled, | 			"wechat_login":        common.WeChatAuthEnabled, | ||||||
| 			"server_address":      config.ServerAddress, | 			"server_address":      common.ServerAddress, | ||||||
| 			"turnstile_check":     config.TurnstileCheckEnabled, | 			"turnstile_check":     common.TurnstileCheckEnabled, | ||||||
| 			"turnstile_site_key":  config.TurnstileSiteKey, | 			"turnstile_site_key":  common.TurnstileSiteKey, | ||||||
| 			"top_up_link":         config.TopUpLink, | 			"top_up_link":         common.TopUpLink, | ||||||
| 			"chat_link":           config.ChatLink, | 			"chat_link":           common.ChatLink, | ||||||
| 			"quota_per_unit":      config.QuotaPerUnit, | 			"quota_per_unit":      common.QuotaPerUnit, | ||||||
| 			"display_in_currency": config.DisplayInCurrencyEnabled, | 			"display_in_currency": common.DisplayInCurrencyEnabled, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetNotice(c *gin.Context) { | func GetNotice(c *gin.Context) { | ||||||
| 	config.OptionMapRWMutex.RLock() | 	common.OptionMapRWMutex.RLock() | ||||||
| 	defer config.OptionMapRWMutex.RUnlock() | 	defer common.OptionMapRWMutex.RUnlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    config.OptionMap["Notice"], | 		"data":    common.OptionMap["Notice"], | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAbout(c *gin.Context) { | func GetAbout(c *gin.Context) { | ||||||
| 	config.OptionMapRWMutex.RLock() | 	common.OptionMapRWMutex.RLock() | ||||||
| 	defer config.OptionMapRWMutex.RUnlock() | 	defer common.OptionMapRWMutex.RUnlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    config.OptionMap["About"], | 		"data":    common.OptionMap["About"], | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetHomePageContent(c *gin.Context) { | func GetHomePageContent(c *gin.Context) { | ||||||
| 	config.OptionMapRWMutex.RLock() | 	common.OptionMapRWMutex.RLock() | ||||||
| 	defer config.OptionMapRWMutex.RUnlock() | 	defer common.OptionMapRWMutex.RUnlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    config.OptionMap["HomePageContent"], | 		"data":    common.OptionMap["HomePageContent"], | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SendEmailVerification(c *gin.Context) { | func SendEmailVerification(c *gin.Context) { | ||||||
| @@ -81,9 +76,9 @@ func SendEmailVerification(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if config.EmailDomainRestrictionEnabled { | 	if common.EmailDomainRestrictionEnabled { | ||||||
| 		allowed := false | 		allowed := false | ||||||
| 		for _, domain := range config.EmailDomainWhitelist { | 		for _, domain := range common.EmailDomainWhitelist { | ||||||
| 			if strings.HasSuffix(email, "@"+domain) { | 			if strings.HasSuffix(email, "@"+domain) { | ||||||
| 				allowed = true | 				allowed = true | ||||||
| 				break | 				break | ||||||
| @@ -106,10 +101,10 @@ func SendEmailVerification(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	code := common.GenerateVerificationCode(6) | 	code := common.GenerateVerificationCode(6) | ||||||
| 	common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) | 	common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) | ||||||
| 	subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) | 	subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) | ||||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | ||||||
| 		"<p>您的验证码为: <strong>%s</strong></p>"+ | 		"<p>您的验证码为: <strong>%s</strong></p>"+ | ||||||
| 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) | 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := common.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -122,7 +117,6 @@ func SendEmailVerification(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SendPasswordResetEmail(c *gin.Context) { | func SendPasswordResetEmail(c *gin.Context) { | ||||||
| @@ -143,12 +137,12 @@ func SendPasswordResetEmail(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	code := common.GenerateVerificationCode(0) | 	code := common.GenerateVerificationCode(0) | ||||||
| 	common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) | 	common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) | ||||||
| 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) | 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) | ||||||
| 	subject := fmt.Sprintf("%s密码重置", config.SystemName) | 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | ||||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | ||||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) | 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := common.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -161,7 +155,6 @@ func SendPasswordResetEmail(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type PasswordResetRequest struct { | type PasswordResetRequest struct { | ||||||
| @@ -201,5 +194,4 @@ func ResetPassword(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    password, | 		"data":    password, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,14 +2,13 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/helper" |  | ||||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/types" | ||||||
|  | 	"sort" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/models/list | // https://platform.openai.com/docs/api-reference/models/list | ||||||
| @@ -30,96 +29,74 @@ type OpenAIModelPermission struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type OpenAIModels struct { | type OpenAIModels struct { | ||||||
| 	Id         string                  `json:"id"` | 	Id         string                   `json:"id"` | ||||||
| 	Object     string                  `json:"object"` | 	Object     string                   `json:"object"` | ||||||
| 	Created    int                     `json:"created"` | 	Created    int                      `json:"created"` | ||||||
| 	OwnedBy    string                  `json:"owned_by"` | 	OwnedBy    *string                  `json:"owned_by"` | ||||||
| 	Permission []OpenAIModelPermission `json:"permission"` | 	Permission *[]OpenAIModelPermission `json:"permission"` | ||||||
| 	Root       string                  `json:"root"` | 	Root       *string                  `json:"root"` | ||||||
| 	Parent     *string                 `json:"parent"` | 	Parent     *string                  `json:"parent"` | ||||||
| } | } | ||||||
|  |  | ||||||
| var openAIModels []OpenAIModels | var openAIModels []OpenAIModels | ||||||
| var openAIModelsMap map[string]OpenAIModels | var openAIModelsMap map[string]OpenAIModels | ||||||
| var channelId2Models map[int][]string |  | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	var permission []OpenAIModelPermission |  | ||||||
| 	permission = append(permission, OpenAIModelPermission{ |  | ||||||
| 		Id:                 "modelperm-LwHkVFn8AcMItP432fKKDIKJ", |  | ||||||
| 		Object:             "model_permission", |  | ||||||
| 		Created:            1626777600, |  | ||||||
| 		AllowCreateEngine:  true, |  | ||||||
| 		AllowSampling:      true, |  | ||||||
| 		AllowLogprobs:      true, |  | ||||||
| 		AllowSearchIndices: false, |  | ||||||
| 		AllowView:          true, |  | ||||||
| 		AllowFineTuning:    false, |  | ||||||
| 		Organization:       "*", |  | ||||||
| 		Group:              nil, |  | ||||||
| 		IsBlocking:         false, |  | ||||||
| 	}) |  | ||||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||||
| 	for i := 0; i < constant.APITypeDummy; i++ { | 	keys := make([]string, 0, len(common.ModelRatio)) | ||||||
| 		if i == constant.APITypeAIProxyLibrary { | 	for k := range common.ModelRatio { | ||||||
| 			continue | 		keys = append(keys, k) | ||||||
| 		} |  | ||||||
| 		adaptor := helper.GetAdaptor(i) |  | ||||||
| 		channelName := adaptor.GetChannelName() |  | ||||||
| 		modelNames := adaptor.GetModelList() |  | ||||||
| 		for _, modelName := range modelNames { |  | ||||||
| 			openAIModels = append(openAIModels, OpenAIModels{ |  | ||||||
| 				Id:         modelName, |  | ||||||
| 				Object:     "model", |  | ||||||
| 				Created:    1626777600, |  | ||||||
| 				OwnedBy:    channelName, |  | ||||||
| 				Permission: permission, |  | ||||||
| 				Root:       modelName, |  | ||||||
| 				Parent:     nil, |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	for _, channelType := range openai.CompatibleChannels { | 	sort.Strings(keys) | ||||||
| 		if channelType == common.ChannelTypeAzure { |  | ||||||
| 			continue | 	for _, modelId := range keys { | ||||||
| 		} | 		openAIModels = append(openAIModels, OpenAIModels{ | ||||||
| 		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) | 			Id:         modelId, | ||||||
| 		for _, modelName := range channelModelList { | 			Object:     "model", | ||||||
| 			openAIModels = append(openAIModels, OpenAIModels{ | 			Created:    1677649963, | ||||||
| 				Id:         modelName, | 			OwnedBy:    nil, | ||||||
| 				Object:     "model", | 			Permission: nil, | ||||||
| 				Created:    1626777600, | 			Root:       nil, | ||||||
| 				OwnedBy:    channelName, | 			Parent:     nil, | ||||||
| 				Permission: permission, | 		}) | ||||||
| 				Root:       modelName, |  | ||||||
| 				Parent:     nil, |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | 	openAIModelsMap = make(map[string]OpenAIModels) | ||||||
| 	for _, model := range openAIModels { | 	for _, model := range openAIModels { | ||||||
| 		openAIModelsMap[model.Id] = model | 		openAIModelsMap[model.Id] = model | ||||||
| 	} | 	} | ||||||
| 	channelId2Models = make(map[int][]string) |  | ||||||
| 	for i := 1; i < common.ChannelTypeDummy; i++ { |  | ||||||
| 		adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) |  | ||||||
| 		meta := &util.RelayMeta{ |  | ||||||
| 			ChannelType: i, |  | ||||||
| 		} |  | ||||||
| 		adaptor.Init(meta) |  | ||||||
| 		channelId2Models[i] = adaptor.GetModelList() |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func DashboardListModels(c *gin.Context) { |  | ||||||
| 	c.JSON(http.StatusOK, gin.H{ |  | ||||||
| 		"success": true, |  | ||||||
| 		"message": "", |  | ||||||
| 		"data":    channelId2Models, |  | ||||||
| 	}) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func ListModels(c *gin.Context) { | func ListModels(c *gin.Context) { | ||||||
|  | 	groupName := c.GetString("group") | ||||||
|  |  | ||||||
|  | 	models, err := model.CacheGetGroupModels(groupName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	sort.Strings(models) | ||||||
|  |  | ||||||
|  | 	groupOpenAIModels := make([]OpenAIModels, 0, len(models)) | ||||||
|  | 	for _, modelId := range models { | ||||||
|  | 		groupOpenAIModels = append(groupOpenAIModels, OpenAIModels{ | ||||||
|  | 			Id:         modelId, | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    nil, | ||||||
|  | 			Permission: nil, | ||||||
|  | 			Root:       nil, | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.JSON(200, gin.H{ | ||||||
|  | 		"object": "list", | ||||||
|  | 		"data":   groupOpenAIModels, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ListModelsForAdmin(c *gin.Context) { | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(200, gin.H{ | ||||||
| 		"object": "list", | 		"object": "list", | ||||||
| 		"data":   openAIModels, | 		"data":   openAIModels, | ||||||
| @@ -131,14 +108,14 @@ func RetrieveModel(c *gin.Context) { | |||||||
| 	if model, ok := openAIModelsMap[modelId]; ok { | 	if model, ok := openAIModelsMap[modelId]; ok { | ||||||
| 		c.JSON(200, model) | 		c.JSON(200, model) | ||||||
| 	} else { | 	} else { | ||||||
| 		Error := relaymodel.Error{ | 		openAIError := types.OpenAIError{ | ||||||
| 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | ||||||
| 			Type:    "invalid_request_error", | 			Type:    "invalid_request_error", | ||||||
| 			Param:   "model", | 			Param:   "model", | ||||||
| 			Code:    "model_not_found", | 			Code:    "model_not_found", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": Error, | 			"error": openAIError, | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,10 +2,9 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -13,17 +12,17 @@ import ( | |||||||
|  |  | ||||||
| func GetOptions(c *gin.Context) { | func GetOptions(c *gin.Context) { | ||||||
| 	var options []*model.Option | 	var options []*model.Option | ||||||
| 	config.OptionMapRWMutex.Lock() | 	common.OptionMapRWMutex.Lock() | ||||||
| 	for k, v := range config.OptionMap { | 	for k, v := range common.OptionMap { | ||||||
| 		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { | 		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		options = append(options, &model.Option{ | 		options = append(options, &model.Option{ | ||||||
| 			Key:   k, | 			Key:   k, | ||||||
| 			Value: helper.Interface2String(v), | 			Value: common.Interface2String(v), | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 	config.OptionMapRWMutex.Unlock() | 	common.OptionMapRWMutex.Unlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| @@ -43,16 +42,8 @@ func UpdateOption(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	switch option.Key { | 	switch option.Key { | ||||||
| 	case "Theme": |  | ||||||
| 		if !config.ValidThemes[option.Value] { |  | ||||||
| 			c.JSON(http.StatusOK, gin.H{ |  | ||||||
| 				"success": false, |  | ||||||
| 				"message": "无效的主题", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	case "GitHubOAuthEnabled": | 	case "GitHubOAuthEnabled": | ||||||
| 		if option.Value == "true" && config.GitHubClientId == "" { | 		if option.Value == "true" && common.GitHubClientId == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | ||||||
| @@ -60,7 +51,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "EmailDomainRestrictionEnabled": | 	case "EmailDomainRestrictionEnabled": | ||||||
| 		if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { | 		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | ||||||
| @@ -68,7 +59,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "WeChatAuthEnabled": | 	case "WeChatAuthEnabled": | ||||||
| 		if option.Value == "true" && config.WeChatServerAddress == "" { | 		if option.Value == "true" && common.WeChatServerAddress == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用微信登录,请先填入微信登录相关配置信息!", | 				"message": "无法启用微信登录,请先填入微信登录相关配置信息!", | ||||||
| @@ -76,7 +67,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "TurnstileCheckEnabled": | 	case "TurnstileCheckEnabled": | ||||||
| 		if option.Value == "true" && config.TurnstileSiteKey == "" { | 		if option.Value == "true" && common.TurnstileSiteKey == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", | 				"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", | ||||||
|   | |||||||
| @@ -2,10 +2,9 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -14,7 +13,7 @@ func GetAllRedemptions(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) | 	redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -106,12 +105,12 @@ func AddRedemption(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	var keys []string | 	var keys []string | ||||||
| 	for i := 0; i < redemption.Count; i++ { | 	for i := 0; i < redemption.Count; i++ { | ||||||
| 		key := helper.GetUUID() | 		key := common.GetUUID() | ||||||
| 		cleanRedemption := model.Redemption{ | 		cleanRedemption := model.Redemption{ | ||||||
| 			UserId:      c.GetInt("id"), | 			UserId:      c.GetInt("id"), | ||||||
| 			Name:        redemption.Name, | 			Name:        redemption.Name, | ||||||
| 			Key:         key, | 			Key:         key, | ||||||
| 			CreatedTime: helper.GetTimestamp(), | 			CreatedTime: common.GetTimestamp(), | ||||||
| 			Quota:       redemption.Quota, | 			Quota:       redemption.Quota, | ||||||
| 		} | 		} | ||||||
| 		err = cleanRedemption.Insert() | 		err = cleanRedemption.Insert() | ||||||
|   | |||||||
							
								
								
									
										94
									
								
								controller/relay-chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								controller/relay-chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"math" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayChat(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var chatRequest types.ChatCompletionRequest | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &chatRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, chatRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[chatRequest.Model] != "" { | ||||||
|  | 		chatRequest.Model = modelMap[chatRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeChatCompletions) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	chatProvider, ok := provider.(providersBase.ChatInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model) | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, chatRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										94
									
								
								controller/relay-completions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								controller/relay-completions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"math" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayCompletions(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var completionRequest types.CompletionRequest | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &completionRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, completionRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[completionRequest.Model] != "" { | ||||||
|  | 		completionRequest.Model = modelMap[completionRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeCompletions) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	completionProvider, ok := provider.(providersBase.CompletionInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model) | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, completionRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										93
									
								
								controller/relay-embeddings.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								controller/relay-embeddings.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayEmbeddings(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var embeddingsRequest types.EmbeddingRequest | ||||||
|  | 	if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||||
|  | 		embeddingsRequest.Model = c.Param("model") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &embeddingsRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, embeddingsRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[embeddingsRequest.Model] != "" { | ||||||
|  | 		embeddingsRequest.Model = modelMap[embeddingsRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeEmbeddings) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, embeddingsRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										106
									
								
								controller/relay-image-edits.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								controller/relay-image-edits.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayImageEdits(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var imageEditRequest types.ImageEditRequest | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageEditRequest.Prompt == "" { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, "field prompt is required") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageEditRequest.Model == "" { | ||||||
|  | 		imageEditRequest.Model = "dall-e-2" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageEditRequest.Size == "" { | ||||||
|  | 		imageEditRequest.Size = "1024x1024" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, imageEditRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[imageEditRequest.Model] != "" { | ||||||
|  | 		imageEditRequest.Model = modelMap[imageEditRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeImagesEdits) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens, err := common.CountTokenImage(imageEditRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										109
									
								
								controller/relay-image-generations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								controller/relay-image-generations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayImageGenerations(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var imageRequest types.ImageRequest | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &imageRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageRequest.Model == "" { | ||||||
|  | 		imageRequest.Model = "dall-e-2" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageRequest.N == 0 { | ||||||
|  | 		imageRequest.N = 1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageRequest.Size == "" { | ||||||
|  | 		imageRequest.Size = "1024x1024" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageRequest.Quality == "" { | ||||||
|  | 		imageRequest.Quality = "standard" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, imageRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[imageRequest.Model] != "" { | ||||||
|  | 		imageRequest.Model = modelMap[imageRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens, err := common.CountTokenImage(imageRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, imageRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										101
									
								
								controller/relay-image-variationsy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								controller/relay-image-variationsy.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayImageVariations(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var imageEditRequest types.ImageEditRequest | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageEditRequest.Model == "" { | ||||||
|  | 		imageEditRequest.Model = "dall-e-2" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if imageEditRequest.Size == "" { | ||||||
|  | 		imageEditRequest.Size = "1024x1024" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, imageEditRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[imageEditRequest.Model] != "" { | ||||||
|  | 		imageEditRequest.Model = modelMap[imageEditRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeImagesVariations) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	imageVariations, ok := provider.(providersBase.ImageVariationsInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens, err := common.CountTokenImage(imageEditRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										93
									
								
								controller/relay-moderations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								controller/relay-moderations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayModerations(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var moderationRequest types.ModerationRequest | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &moderationRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if moderationRequest.Model == "" { | ||||||
|  | 		moderationRequest.Model = "text-moderation-stable" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, moderationRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[moderationRequest.Model] != "" { | ||||||
|  | 		moderationRequest.Model = modelMap[moderationRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeModerations) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	moderationProvider, ok := provider.(providersBase.ModerationInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model) | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, moderationRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										89
									
								
								controller/relay-speech.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								controller/relay-speech.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelaySpeech(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var speechRequest types.SpeechAudioRequest | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &speechRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, speechRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[speechRequest.Model] != "" { | ||||||
|  | 		speechRequest.Model = modelMap[speechRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	speechProvider, ok := provider.(providersBase.SpeechInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens := len(speechRequest.Input) | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, speechRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										89
									
								
								controller/relay-transcriptions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								controller/relay-transcriptions.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayTranscriptions(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var audioRequest types.AudioRequest | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, audioRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[audioRequest.Model] != "" { | ||||||
|  | 		audioRequest.Model = modelMap[audioRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens := 0 | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = transcriptionsProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										89
									
								
								controller/relay-translations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								controller/relay-translations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RelayTranslations(c *gin.Context) { | ||||||
|  |  | ||||||
|  | 	var audioRequest types.AudioRequest | ||||||
|  |  | ||||||
|  | 	if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel, pass := fetchChannel(c, audioRequest.Model) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 解析模型映射 | ||||||
|  | 	var isModelMapped bool | ||||||
|  | 	modelMap, err := parseModelMapping(channel.GetModelMapping()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if modelMap != nil && modelMap[audioRequest.Model] != "" { | ||||||
|  | 		audioRequest.Model = modelMap[audioRequest.Model] | ||||||
|  | 		isModelMapped = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取供应商 | ||||||
|  | 	provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	translationProvider, ok := provider.(providersBase.TranslationInterface) | ||||||
|  | 	if !ok { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 获取Input Tokens | ||||||
|  | 	promptTokens := 0 | ||||||
|  |  | ||||||
|  | 	var quotaInfo *QuotaInfo | ||||||
|  | 	var errWithCode *types.OpenAIErrorWithStatusCode | ||||||
|  | 	var usage *types.Usage | ||||||
|  | 	quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage, errWithCode = translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens) | ||||||
|  |  | ||||||
|  | 	// 如果报错,则退还配额 | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		tokenId := c.GetInt("token_id") | ||||||
|  | 		if quotaInfo.HandelStatus { | ||||||
|  | 			go func(ctx context.Context) { | ||||||
|  | 				// return pre-consumed quota | ||||||
|  | 				err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 				} | ||||||
|  | 			}(c.Request.Context()) | ||||||
|  | 		} | ||||||
|  | 		errorHelper(c, errWithCode) | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		tokenName := c.GetString("token_name") | ||||||
|  | 		// 如果没有报错,则消费配额 | ||||||
|  | 		go func(ctx context.Context) { | ||||||
|  | 			err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.LogError(ctx, err.Error()) | ||||||
|  | 			} | ||||||
|  | 		}(c.Request.Context()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										280
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										280
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,280 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/providers" | ||||||
|  | 	providersBase "one-api/providers/base" | ||||||
|  | 	"one-api/types" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/go-playground/validator/v10" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func GetValidFieldName(err error, obj interface{}) string { | ||||||
|  | 	getObj := reflect.TypeOf(obj) | ||||||
|  | 	if errs, ok := err.(validator.ValidationErrors); ok { | ||||||
|  | 		for _, e := range errs { | ||||||
|  | 			if f, exist := getObj.Elem().FieldByName(e.Field()); exist { | ||||||
|  | 				return f.Name | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return err.Error() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) { | ||||||
|  | 	channelId, ok := c.Get("channelId") | ||||||
|  | 	if ok { | ||||||
|  | 		channel, pass = fetchChannelById(c, channelId.(int)) | ||||||
|  | 		if pass { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 	} | ||||||
|  | 	channel, pass = fetchChannelByModel(c, modelName) | ||||||
|  | 	if pass { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.Set("channel_id", channel.Id) | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func fetchChannelById(c *gin.Context, channelId any) (*model.Channel, bool) { | ||||||
|  | 	id, err := strconv.Atoi(channelId.(string)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||||
|  | 		return nil, true | ||||||
|  | 	} | ||||||
|  | 	channel, err := model.GetChannelById(id, true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||||
|  | 		return nil, true | ||||||
|  | 	} | ||||||
|  | 	if channel.Status != common.ChannelStatusEnabled { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||||
|  | 		return nil, true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return channel, false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) { | ||||||
|  | 	group := c.GetString("group") | ||||||
|  | 	channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName) | ||||||
|  | 		if channel != nil { | ||||||
|  | 			common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||||
|  | 			message = "数据库一致性已被破坏,请联系管理员" | ||||||
|  | 		} | ||||||
|  | 		common.AbortWithMessage(c, http.StatusServiceUnavailable, message) | ||||||
|  | 		return nil, true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return channel, false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) { | ||||||
|  | 	provider := providers.GetProvider(channel, c) | ||||||
|  | 	if provider == nil { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found") | ||||||
|  | 		return nil, true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !provider.SupportAPI(relayMode) { | ||||||
|  | 		common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API") | ||||||
|  | 		return nil, true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return provider, false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool { | ||||||
|  | 	if !common.AutomaticDisableChannelEnabled { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if err == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusUnauthorized { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { | ||||||
|  | 	if !common.AutomaticEnableChannelEnabled { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if openAIErr != nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||||
|  | 	// quotaDelta is remaining quota to be consumed | ||||||
|  | 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("error consuming token remain quota: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	err = model.CacheUpdateUserQuota(userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("error update user quota cache: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	// totalQuota is total quota consumed | ||||||
|  | 	if totalQuota != 0 { | ||||||
|  | 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
|  | 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | ||||||
|  | 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||||
|  | 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||||
|  | 	} | ||||||
|  | 	if totalQuota <= 0 { | ||||||
|  | 		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseModelMapping(modelMapping string) (map[string]string, error) { | ||||||
|  | 	if modelMapping == "" || modelMapping == "{}" { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | 	modelMap := make(map[string]string) | ||||||
|  | 	err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return modelMap, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type QuotaInfo struct { | ||||||
|  | 	modelName         string | ||||||
|  | 	promptTokens      int | ||||||
|  | 	preConsumedTokens int | ||||||
|  | 	modelRatio        float64 | ||||||
|  | 	groupRatio        float64 | ||||||
|  | 	ratio             float64 | ||||||
|  | 	preConsumedQuota  int | ||||||
|  | 	userId            int | ||||||
|  | 	channelId         int | ||||||
|  | 	tokenId           int | ||||||
|  | 	HandelStatus      bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	quotaInfo := &QuotaInfo{ | ||||||
|  | 		modelName:    modelName, | ||||||
|  | 		promptTokens: promptTokens, | ||||||
|  | 		userId:       c.GetInt("id"), | ||||||
|  | 		channelId:    c.GetInt("channel_id"), | ||||||
|  | 		tokenId:      c.GetInt("token_id"), | ||||||
|  | 		HandelStatus: false, | ||||||
|  | 	} | ||||||
|  | 	quotaInfo.initQuotaInfo(c.GetString("group")) | ||||||
|  |  | ||||||
|  | 	errWithCode := quotaInfo.preQuotaConsumption() | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return nil, errWithCode | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return quotaInfo, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (q *QuotaInfo) initQuotaInfo(groupName string) { | ||||||
|  | 	modelRatio := common.GetModelRatio(q.modelName) | ||||||
|  | 	groupRatio := common.GetGroupRatio(groupName) | ||||||
|  | 	preConsumedTokens := common.PreConsumedQuota | ||||||
|  | 	ratio := modelRatio * groupRatio | ||||||
|  | 	preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio) | ||||||
|  |  | ||||||
|  | 	q.preConsumedTokens = preConsumedTokens | ||||||
|  | 	q.modelRatio = modelRatio | ||||||
|  | 	q.groupRatio = groupRatio | ||||||
|  | 	q.ratio = ratio | ||||||
|  | 	q.preConsumedQuota = preConsumedQuota | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { | ||||||
|  | 	userQuota, err := model.CacheGetUserQuota(q.userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if userQuota < q.preConsumedQuota { | ||||||
|  | 		return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if userQuota > 100*q.preConsumedQuota { | ||||||
|  | 		// in this case, we do not pre-consume quota | ||||||
|  | 		// because the user has enough quota | ||||||
|  | 		q.preConsumedQuota = 0 | ||||||
|  | 		// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if q.preConsumedQuota > 0 { | ||||||
|  | 		err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||||
|  | 		} | ||||||
|  | 		q.HandelStatus = true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error { | ||||||
|  | 	quota := 0 | ||||||
|  | 	completionRatio := common.GetCompletionRatio(q.modelName) | ||||||
|  | 	promptTokens := usage.PromptTokens | ||||||
|  | 	completionTokens := usage.CompletionTokens | ||||||
|  | 	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio)) | ||||||
|  | 	if q.ratio != 0 && quota <= 0 { | ||||||
|  | 		quota = 1 | ||||||
|  | 	} | ||||||
|  | 	totalTokens := promptTokens + completionTokens | ||||||
|  | 	if totalTokens == 0 { | ||||||
|  | 		// in this case, must be some error happened | ||||||
|  | 		// we cannot just return, because we may have to return the pre-consumed quota | ||||||
|  | 		quota = 0 | ||||||
|  | 	} | ||||||
|  | 	quotaDelta := quota - q.preConsumedQuota | ||||||
|  | 	err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.New("error consuming token remain quota: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	err = model.CacheUpdateUserQuota(q.userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.New("error consuming token remain quota: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	if quota != 0 { | ||||||
|  | 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio) | ||||||
|  | 		model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent) | ||||||
|  | 		model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota) | ||||||
|  | 		model.UpdateChannelUsedQuota(q.channelId, quota) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
| @@ -1,128 +1,17 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" |  | ||||||
| 	"context" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"github.com/songquanpeng/one-api/middleware" |  | ||||||
| 	dbmodel "github.com/songquanpeng/one-api/model" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/constant" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/controller" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/model" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/util" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/types" | ||||||
|  | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/chat |  | ||||||
|  |  | ||||||
| func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { |  | ||||||
| 	var err *model.ErrorWithStatusCode |  | ||||||
| 	switch relayMode { |  | ||||||
| 	case constant.RelayModeImagesGenerations: |  | ||||||
| 		err = controller.RelayImageHelper(c, relayMode) |  | ||||||
| 	case constant.RelayModeAudioSpeech: |  | ||||||
| 		fallthrough |  | ||||||
| 	case constant.RelayModeAudioTranslation: |  | ||||||
| 		fallthrough |  | ||||||
| 	case constant.RelayModeAudioTranscription: |  | ||||||
| 		err = controller.RelayAudioHelper(c, relayMode) |  | ||||||
| 	default: |  | ||||||
| 		err = controller.RelayTextHelper(c) |  | ||||||
| 	} |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Relay(c *gin.Context) { |  | ||||||
| 	ctx := c.Request.Context() |  | ||||||
| 	relayMode := constant.Path2RelayMode(c.Request.URL.Path) |  | ||||||
| 	if config.DebugEnabled { |  | ||||||
| 		requestBody, _ := common.GetRequestBody(c) |  | ||||||
| 		logger.Debugf(ctx, "request body: %s", string(requestBody)) |  | ||||||
| 	} |  | ||||||
| 	bizErr := relay(c, relayMode) |  | ||||||
| 	if bizErr == nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
| 	lastFailedChannelId := channelId |  | ||||||
| 	channelName := c.GetString("channel_name") |  | ||||||
| 	group := c.GetString("group") |  | ||||||
| 	originalModel := c.GetString("original_model") |  | ||||||
| 	go processChannelRelayError(ctx, channelId, channelName, bizErr) |  | ||||||
| 	requestId := c.GetString(logger.RequestIdKey) |  | ||||||
| 	retryTimes := config.RetryTimes |  | ||||||
| 	if !shouldRetry(c, bizErr.StatusCode) { |  | ||||||
| 		logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) |  | ||||||
| 		retryTimes = 0 |  | ||||||
| 	} |  | ||||||
| 	for i := retryTimes; i > 0; i-- { |  | ||||||
| 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 		logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) |  | ||||||
| 		if channel.Id == lastFailedChannelId { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		middleware.SetupContextForSelectedChannel(c, channel, originalModel) |  | ||||||
| 		requestBody, err := common.GetRequestBody(c) |  | ||||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) |  | ||||||
| 		bizErr = relay(c, relayMode) |  | ||||||
| 		if bizErr == nil { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		channelId := c.GetInt("channel_id") |  | ||||||
| 		lastFailedChannelId = channelId |  | ||||||
| 		channelName := c.GetString("channel_name") |  | ||||||
| 		go processChannelRelayError(ctx, channelId, channelName, bizErr) |  | ||||||
| 	} |  | ||||||
| 	if bizErr != nil { |  | ||||||
| 		if bizErr.StatusCode == http.StatusTooManyRequests { |  | ||||||
| 			bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" |  | ||||||
| 		} |  | ||||||
| 		bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) |  | ||||||
| 		c.JSON(bizErr.StatusCode, gin.H{ |  | ||||||
| 			"error": bizErr.Error, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func shouldRetry(c *gin.Context, statusCode int) bool { |  | ||||||
| 	if _, ok := c.Get("specific_channel_id"); ok { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	if statusCode == http.StatusTooManyRequests { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	if statusCode/100 == 5 { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	if statusCode == http.StatusBadRequest { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	if statusCode/100 == 2 { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	return true |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { |  | ||||||
| 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) |  | ||||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors |  | ||||||
| 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { |  | ||||||
| 		disableChannel(channelId, channelName, err.Message) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func RelayNotImplemented(c *gin.Context) { | func RelayNotImplemented(c *gin.Context) { | ||||||
| 	err := model.Error{ | 	err := types.OpenAIError{ | ||||||
| 		Message: "API not implemented", | 		Message: "API not implemented", | ||||||
| 		Type:    "one_api_error", | 		Type:    "one_api_error", | ||||||
| 		Param:   "", | 		Param:   "", | ||||||
| @@ -134,7 +23,7 @@ func RelayNotImplemented(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func RelayNotFound(c *gin.Context) { | func RelayNotFound(c *gin.Context) { | ||||||
| 	err := model.Error{ | 	err := types.OpenAIError{ | ||||||
| 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | ||||||
| 		Type:    "invalid_request_error", | 		Type:    "invalid_request_error", | ||||||
| 		Param:   "", | 		Param:   "", | ||||||
| @@ -144,3 +33,31 @@ func RelayNotFound(c *gin.Context) { | |||||||
| 		"error": err, | 		"error": err, | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func errorHelper(c *gin.Context, err *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	requestId := c.GetString(common.RequestIdKey) | ||||||
|  | 	retryTimesStr := c.Query("retry") | ||||||
|  | 	retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||||
|  | 	if retryTimesStr == "" { | ||||||
|  | 		retryTimes = common.RetryTimes | ||||||
|  | 	} | ||||||
|  | 	if retryTimes > 0 { | ||||||
|  | 		c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||||
|  | 	} else { | ||||||
|  | 		if err.StatusCode == http.StatusTooManyRequests { | ||||||
|  | 			err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | ||||||
|  | 		} | ||||||
|  | 		err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | ||||||
|  | 		c.JSON(err.StatusCode, gin.H{ | ||||||
|  | 			"error": err.OpenAIError, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
|  | 	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||||
|  | 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
|  | 	if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | ||||||
|  | 		channelId := c.GetInt("channel_id") | ||||||
|  | 		channelName := c.GetString("channel_name") | ||||||
|  | 		disableChannel(channelId, channelName, err.Message) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,13 +1,12 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetAllTokens(c *gin.Context) { | func GetAllTokens(c *gin.Context) { | ||||||
| @@ -16,7 +15,7 @@ func GetAllTokens(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) | 	tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -29,7 +28,6 @@ func GetAllTokens(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    tokens, | 		"data":    tokens, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchTokens(c *gin.Context) { | func SearchTokens(c *gin.Context) { | ||||||
| @@ -48,7 +46,6 @@ func SearchTokens(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    tokens, | 		"data":    tokens, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetToken(c *gin.Context) { | func GetToken(c *gin.Context) { | ||||||
| @@ -74,7 +71,6 @@ func GetToken(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    token, | 		"data":    token, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetTokenStatus(c *gin.Context) { | func GetTokenStatus(c *gin.Context) { | ||||||
| @@ -121,9 +117,9 @@ func AddToken(c *gin.Context) { | |||||||
| 	cleanToken := model.Token{ | 	cleanToken := model.Token{ | ||||||
| 		UserId:         c.GetInt("id"), | 		UserId:         c.GetInt("id"), | ||||||
| 		Name:           token.Name, | 		Name:           token.Name, | ||||||
| 		Key:            helper.GenerateKey(), | 		Key:            common.GenerateKey(), | ||||||
| 		CreatedTime:    helper.GetTimestamp(), | 		CreatedTime:    common.GetTimestamp(), | ||||||
| 		AccessedTime:   helper.GetTimestamp(), | 		AccessedTime:   common.GetTimestamp(), | ||||||
| 		ExpiredTime:    token.ExpiredTime, | 		ExpiredTime:    token.ExpiredTime, | ||||||
| 		RemainQuota:    token.RemainQuota, | 		RemainQuota:    token.RemainQuota, | ||||||
| 		UnlimitedQuota: token.UnlimitedQuota, | 		UnlimitedQuota: token.UnlimitedQuota, | ||||||
| @@ -140,7 +136,6 @@ func AddToken(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteToken(c *gin.Context) { | func DeleteToken(c *gin.Context) { | ||||||
| @@ -158,7 +153,6 @@ func DeleteToken(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateToken(c *gin.Context) { | func UpdateToken(c *gin.Context) { | ||||||
| @@ -189,7 +183,7 @@ func UpdateToken(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if token.Status == common.TokenStatusEnabled { | 	if token.Status == common.TokenStatusEnabled { | ||||||
| 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { | 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", | 				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", | ||||||
| @@ -226,5 +220,4 @@ func UpdateToken(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    cleanToken, | 		"data":    cleanToken, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,11 +3,9 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -21,7 +19,7 @@ type LoginRequest struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Login(c *gin.Context) { | func Login(c *gin.Context) { | ||||||
| 	if !config.PasswordLoginEnabled { | 	if !common.PasswordLoginEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员关闭了密码登录", | 			"message": "管理员关闭了密码登录", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -108,14 +106,14 @@ func Logout(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Register(c *gin.Context) { | func Register(c *gin.Context) { | ||||||
| 	if !config.RegisterEnabled { | 	if !common.RegisterEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员关闭了新用户注册", | 			"message": "管理员关闭了新用户注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if !config.PasswordRegisterEnabled { | 	if !common.PasswordRegisterEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", | 			"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -138,7 +136,7 @@ func Register(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if config.EmailVerificationEnabled { | 	if common.EmailVerificationEnabled { | ||||||
| 		if user.Email == "" || user.VerificationCode == "" { | 		if user.Email == "" || user.VerificationCode == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| @@ -162,7 +160,7 @@ func Register(c *gin.Context) { | |||||||
| 		DisplayName: user.Username, | 		DisplayName: user.Username, | ||||||
| 		InviterId:   inviterId, | 		InviterId:   inviterId, | ||||||
| 	} | 	} | ||||||
| 	if config.EmailVerificationEnabled { | 	if common.EmailVerificationEnabled { | ||||||
| 		cleanUser.Email = user.Email | 		cleanUser.Email = user.Email | ||||||
| 	} | 	} | ||||||
| 	if err := cleanUser.Insert(inviterId); err != nil { | 	if err := cleanUser.Insert(inviterId); err != nil { | ||||||
| @@ -176,7 +174,6 @@ func Register(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllUsers(c *gin.Context) { | func GetAllUsers(c *gin.Context) { | ||||||
| @@ -184,7 +181,7 @@ func GetAllUsers(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) | 	users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -197,7 +194,6 @@ func GetAllUsers(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    users, | 		"data":    users, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUsers(c *gin.Context) { | func SearchUsers(c *gin.Context) { | ||||||
| @@ -215,7 +211,6 @@ func SearchUsers(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    users, | 		"data":    users, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUser(c *gin.Context) { | func GetUser(c *gin.Context) { | ||||||
| @@ -248,30 +243,30 @@ func GetUser(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    user, | 		"data":    user, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserDashboard(c *gin.Context) { | func GetUserDashboard(c *gin.Context) { | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt("id") | ||||||
|  | 	// 获取7天前 00:00:00 和 今天23:59:59  的秒时间戳 | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() | 	toDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) | ||||||
| 	endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() | 	endOfDay := toDay.Add(time.Hour * 24).Add(-time.Second).Unix() | ||||||
|  | 	startOfDay := toDay.AddDate(0, 0, -7).Unix() | ||||||
|  |  | ||||||
| 	dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay)) | 	dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "无法获取统计信息", | 			"message": "无法获取统计信息.", | ||||||
| 			"data":    nil, |  | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    dashboards, | 		"data":    dashboards, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GenerateAccessToken(c *gin.Context) { | func GenerateAccessToken(c *gin.Context) { | ||||||
| @@ -284,7 +279,7 @@ func GenerateAccessToken(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	user.AccessToken = helper.GetUUID() | 	user.AccessToken = common.GetUUID() | ||||||
|  |  | ||||||
| 	if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { | 	if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -307,7 +302,6 @@ func GenerateAccessToken(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    user.AccessToken, | 		"data":    user.AccessToken, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAffCode(c *gin.Context) { | func GetAffCode(c *gin.Context) { | ||||||
| @@ -321,7 +315,7 @@ func GetAffCode(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if user.AffCode == "" { | 	if user.AffCode == "" { | ||||||
| 		user.AffCode = helper.GetRandomString(4) | 		user.AffCode = common.GetRandomString(4) | ||||||
| 		if err := user.Update(false); err != nil { | 		if err := user.Update(false); err != nil { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| @@ -335,7 +329,6 @@ func GetAffCode(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    user.AffCode, | 		"data":    user.AffCode, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetSelf(c *gin.Context) { | func GetSelf(c *gin.Context) { | ||||||
| @@ -353,7 +346,6 @@ func GetSelf(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    user, | 		"data":    user, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateUser(c *gin.Context) { | func UpdateUser(c *gin.Context) { | ||||||
| @@ -417,7 +409,6 @@ func UpdateUser(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateSelf(c *gin.Context) { | func UpdateSelf(c *gin.Context) { | ||||||
| @@ -464,7 +455,6 @@ func UpdateSelf(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteUser(c *gin.Context) { | func DeleteUser(c *gin.Context) { | ||||||
| @@ -526,7 +516,6 @@ func DeleteSelf(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func CreateUser(c *gin.Context) { | func CreateUser(c *gin.Context) { | ||||||
| @@ -575,7 +564,6 @@ func CreateUser(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type ManageRequest struct { | type ManageRequest struct { | ||||||
| @@ -692,7 +680,6 @@ func ManageUser(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    clearUser, | 		"data":    clearUser, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func EmailBind(c *gin.Context) { | func EmailBind(c *gin.Context) { | ||||||
| @@ -728,13 +715,12 @@ func EmailBind(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if user.Role == common.RoleRootUser { | 	if user.Role == common.RoleRootUser { | ||||||
| 		config.RootUserEmail = email | 		common.RootUserEmail = email | ||||||
| 	} | 	} | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type topUpRequest struct { | type topUpRequest struct { | ||||||
| @@ -765,5 +751,4 @@ func TopUp(c *gin.Context) { | |||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    quota, | 		"data":    quota, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,13 +4,13 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type wechatLoginResponse struct { | type wechatLoginResponse struct { | ||||||
| @@ -23,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) { | |||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return "", errors.New("无效的参数") | 		return "", errors.New("无效的参数") | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) | 	req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 	req.Header.Set("Authorization", config.WeChatServerToken) | 	req.Header.Set("Authorization", common.WeChatServerToken) | ||||||
| 	client := http.Client{ | 	client := http.Client{ | ||||||
| 		Timeout: 5 * time.Second, | 		Timeout: 5 * time.Second, | ||||||
| 	} | 	} | ||||||
| @@ -51,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func WeChatAuth(c *gin.Context) { | func WeChatAuth(c *gin.Context) { | ||||||
| 	if !config.WeChatAuthEnabled { | 	if !common.WeChatAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员未开启通过微信登录以及注册", | 			"message": "管理员未开启通过微信登录以及注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -80,7 +80,7 @@ func WeChatAuth(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		if config.RegisterEnabled { | 		if common.RegisterEnabled { | ||||||
| 			user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) | 			user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||||
| 			user.DisplayName = "WeChat User" | 			user.DisplayName = "WeChat User" | ||||||
| 			user.Role = common.RoleCommonUser | 			user.Role = common.RoleCommonUser | ||||||
| @@ -113,7 +113,7 @@ func WeChatAuth(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func WeChatBind(c *gin.Context) { | func WeChatBind(c *gin.Context) { | ||||||
| 	if !config.WeChatAuthEnabled { | 	if !common.WeChatAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员未开启通过微信登录以及注册", | 			"message": "管理员未开启通过微信登录以及注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -161,5 +161,4 @@ func WeChatBind(c *gin.Context) { | |||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							| @@ -1,4 +1,4 @@ | |||||||
| module github.com/songquanpeng/one-api | module one-api | ||||||
|  |  | ||||||
| // +heroku goVersion go1.18 | // +heroku goVersion go1.18 | ||||||
| go 1.18 | go 1.18 | ||||||
| @@ -45,6 +45,7 @@ require ( | |||||||
| 	github.com/jackc/pgx/v5 v5.3.1 // indirect | 	github.com/jackc/pgx/v5 v5.3.1 // indirect | ||||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||||
| 	github.com/jinzhu/now v1.1.5 // indirect | 	github.com/jinzhu/now v1.1.5 // indirect | ||||||
|  | 	github.com/joho/godotenv v1.5.1 // indirect | ||||||
| 	github.com/json-iterator/go v1.1.12 // indirect | 	github.com/json-iterator/go v1.1.12 // indirect | ||||||
| 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | ||||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | 	github.com/leodido/go-urn v1.2.4 // indirect | ||||||
| @@ -57,7 +58,7 @@ require ( | |||||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||||
| 	golang.org/x/arch v0.3.0 // indirect | 	golang.org/x/arch v0.3.0 // indirect | ||||||
| 	golang.org/x/net v0.17.0 // indirect | 	golang.org/x/net v0.19.0 // indirect | ||||||
| 	golang.org/x/sys v0.15.0 // indirect | 	golang.org/x/sys v0.15.0 // indirect | ||||||
| 	golang.org/x/text v0.14.0 // indirect | 	golang.org/x/text v0.14.0 // indirect | ||||||
| 	google.golang.org/protobuf v1.30.0 // indirect | 	google.golang.org/protobuf v1.30.0 // indirect | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							| @@ -80,6 +80,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr | |||||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||||
| github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | ||||||
| github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||||
|  | github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= | ||||||
|  | github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= | ||||||
| github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= | github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= | ||||||
| github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | ||||||
| github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | ||||||
| @@ -157,6 +159,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE | |||||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||||
| golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | ||||||
| golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | ||||||
|  | golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= | ||||||
|  | golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= | ||||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
|   | |||||||
							
								
								
									
										250
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										250
									
								
								i18n/en.json
									
									
									
									
									
								
							| @@ -86,7 +86,6 @@ | |||||||
|   "该令牌已过期": "The token has expired", |   "该令牌已过期": "The token has expired", | ||||||
|   "该令牌额度已用尽": "The token quota has been used up", |   "该令牌额度已用尽": "The token quota has been used up", | ||||||
|   "无效的令牌": "Invalid token", |   "无效的令牌": "Invalid token", | ||||||
|   "令牌验证失败": "Token verification failed", |  | ||||||
|   "id 或 userId 为空!": "id or userId is empty!", |   "id 或 userId 为空!": "id or userId is empty!", | ||||||
|   "quota 不能为负数!": "quota cannot be negative!", |   "quota 不能为负数!": "quota cannot be negative!", | ||||||
|   "令牌额度不足": "Insufficient token quota", |   "令牌额度不足": "Insufficient token quota", | ||||||
| @@ -456,11 +455,9 @@ | |||||||
|   "已绑定的邮箱账户": "Email Account Bound", |   "已绑定的邮箱账户": "Email Account Bound", | ||||||
|   "用户信息更新成功!": "User information updated successfully!", |   "用户信息更新成功!": "User information updated successfully!", | ||||||
|   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", |   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", | ||||||
|   "模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f", |  | ||||||
|   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", |   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", | ||||||
|   "用户名称": "User Name", |   "用户名称": "User Name", | ||||||
|   "令牌名称": "Token Name", |   "令牌名称": "Token Name", | ||||||
|   "默认令牌": "Default Token", |  | ||||||
|   "留空则查询全部用户": "Leave blank to query all users", |   "留空则查询全部用户": "Leave blank to query all users", | ||||||
|   "留空则查询全部令牌": "Leave blank to query all tokens", |   "留空则查询全部令牌": "Leave blank to query all tokens", | ||||||
|   "模型名称": "Model Name", |   "模型名称": "Model Name", | ||||||
| @@ -529,250 +526,5 @@ | |||||||
|   "模型版本": "Model version", |   "模型版本": "Model version", | ||||||
|   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", |   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", | ||||||
|   "点击查看": "click to view", |   "点击查看": "click to view", | ||||||
|   "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!", |   "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" | ||||||
|   "处理中...": "Processing...", |  | ||||||
|   "绑定成功!": "Binding successful!", |  | ||||||
|   "登录成功!": "Login successful!", |  | ||||||
|   "操作失败,重定向至登录界面中...": "Operation failed, redirecting to login screen...", |  | ||||||
|   "出现错误,第 ${count} 次重试中...": "An error occurred, retrying ${count}...", |  | ||||||
|   "首页": "Home", |  | ||||||
|   "渠道": "Channel", |  | ||||||
|   "令牌": "API Keys", |  | ||||||
|   "兑换": "Redeem", |  | ||||||
|   "充值": "Recharge", |  | ||||||
|   "用户": "Users", |  | ||||||
|   "日志": "Logs", |  | ||||||
|   "设置": "Settings", |  | ||||||
|   "关于": "About", |  | ||||||
|   "聊天": "Chat", |  | ||||||
|   "注销成功!": "Logout successful!", |  | ||||||
|   "注销": "Log out", |  | ||||||
|   "登录": "Log in", |  | ||||||
|   "注册": "Sign up", |  | ||||||
|   "加载{name}中...": "Loading {name}...", |  | ||||||
|   "未登录或登录已过期,请重新登录!": "Not logged in or login has expired, please log in again!", |  | ||||||
|   "请立刻修改默认密码!": "Please change the default password immediately!", |  | ||||||
|   "欢迎回来": "Welcome back", |  | ||||||
|   "没有账户?": "No account?", |  | ||||||
|   "立刻注册": "Sign up now", |  | ||||||
|   "用户名": "Username", |  | ||||||
|   "密码": "Password", |  | ||||||
|   "正在登录……": "Logging in...", |  | ||||||
|   "忘记密码": "Forgot password", |  | ||||||
|   "其他方式": "Other methods", |  | ||||||
|   "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "Scan the QR code with WeChat, follow the official account and enter 'verification code' to get the verification code (valid within three minutes)", |  | ||||||
|   "验证码": "Verification code", |  | ||||||
|   "全部用户": "All users", |  | ||||||
|   "当前用户": "Current user", |  | ||||||
|   "全部": "All", |  | ||||||
|   "消费": "Consumption", |  | ||||||
|   "管理": "Management", |  | ||||||
|   "系统": "System", |  | ||||||
|   "未知": "Unknown", |  | ||||||
|   "其他模型": "Other models", |  | ||||||
|   "复制成功": "Copy successful", |  | ||||||
|   "使用明细": "Usages", |  | ||||||
|   "刷新": "Refresh", |  | ||||||
|   "收起面板": "Collapse panel", |  | ||||||
|   "展开面板": "Expand panel", |  | ||||||
|   "显示查询选项": "Show search options", |  | ||||||
|   "隐藏查询选项": "Hide search options", |  | ||||||
|   "用户名称": "User name", |  | ||||||
|   "可选值": "Optional values", |  | ||||||
|   "渠道 ID": "Channel ID", |  | ||||||
|   "令牌名称": "Key name", |  | ||||||
|   "模型名称": "Model name", |  | ||||||
|   "起始时间": "Start time", |  | ||||||
|   "结束时间": "End time", |  | ||||||
|   "查询": "Query", |  | ||||||
|   "隐藏条形图": "Hide bar chart", |  | ||||||
|   "显示条形图": "Show bar chart", |  | ||||||
|   "折线条形图只展示最新50条数据": "Line and bar charts only show the latest 50 pieces of data", |  | ||||||
|   "总消耗": "Total consumption", |  | ||||||
|   "总共调用了 {payload[0].value} 次": "A total of {payload[0].value} calls were made", |  | ||||||
|   "{model.name}: {model.value} 次": "{model.name}: {model.value} times", |  | ||||||
|   "总共调用了 {payload[0].value} 次 {payload[0].name}": "A total of {payload[0].value} {payload[0].name} calls were made", |  | ||||||
|   "总消耗额度": "Total consumption limit", |  | ||||||
|   "暂无数据": "No data available", |  | ||||||
|   "更多数据统计图形即将到来,敬请期待!": "More data statistics graphics are coming soon, stay tuned!", |  | ||||||
|   "复制用户名": "Copy username", |  | ||||||
|   "{`共 ${counts} 条数据`}": "{`A total of ${counts} pieces of data`}", |  | ||||||
|   "共 0 条数据": "A total of 0 pieces of data", |  | ||||||
|   "选择明细分类": "Select detail category", |  | ||||||
|   "模型倍率": "model rate", |  | ||||||
|   "分组倍率": "group rate", |  | ||||||
|   "新密码已复制到剪贴板:": "New password has been copied to the clipboard:", |  | ||||||
|   "密码重置确认": "Password reset confirmation", |  | ||||||
|   "邮箱地址": "Email address", |  | ||||||
|   "新密码": "New password", |  | ||||||
|   "密码已复制到剪贴板:": "Password has been copied to the clipboard:", |  | ||||||
|   "密码重置完成": "Password reset complete", |  | ||||||
|   "提交": "Submit", |  | ||||||
|   "返回登录": "Return to login", |  | ||||||
|   "请稍后重试,浏览器环境检查未通过": "Please try again later, browser environment check failed", |  | ||||||
|   "重置邮件发送成功,请检查邮箱!": "Reset email sent successfully, please check your email!", |  | ||||||
|   "密码重置": "Password reset", |  | ||||||
|   "重试": "Retry", |  | ||||||
|   "组": "Group", |  | ||||||
|   "令牌已重置并已复制到剪贴板": "Token has been reset and copied to the clipboard", |  | ||||||
|   "邀请链接已复制到剪切板": "Invitation link has been copied to the clipboard", |  | ||||||
|   "系统令牌已复制到剪切板": "System token has been copied to the clipboard", |  | ||||||
|   "请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!", |  | ||||||
|   "账户已删除!": "Account has been deleted!", |  | ||||||
|   "微信账户绑定成功!": "WeChat account binding successful!", |  | ||||||
|   "请稍后几秒重试,Turnstile 正在检查用户环境!": "Please try again in a few seconds, Turnstile is checking the user environment!", |  | ||||||
|   "验证码发送成功,请检查邮箱!": "Verification code sent successfully, please check your email!", |  | ||||||
|   "邮箱账户绑定成功!": "Email account binding successful!", |  | ||||||
|   "个人信息": "Personal information", |  | ||||||
|   "编辑个人信息": "Edit personal information", |  | ||||||
|   "生成系统访问令牌": "Generate system access token", |  | ||||||
|   "复制邀请链接": "Copy invitation link", |  | ||||||
|   "删除个人帐户": "Delete personal account", |  | ||||||
|   "普通用户": "Regular user", |  | ||||||
|   "管理员": "Administrator", |  | ||||||
|   "超级管理员": "Super administrator", |  | ||||||
|   "显示名称": "Display name", |  | ||||||
|   "GitHub 账号": "GitHub account", |  | ||||||
|   "微信账号": "WeChat account", |  | ||||||
|   "修改个人信息只允许在电脑端进行。生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。": "Modifying personal information is only allowed on a computer. The generated token is for system management, not for requesting OpenAI related services. Please be aware.", |  | ||||||
|   "可用模型": "Available models", |  | ||||||
|   "账号绑定": "Account binding", |  | ||||||
|   "绑定微信": "Bind WeChat", |  | ||||||
|   "绑定 GitHub": "Bind GitHub", |  | ||||||
|   "绑定邮箱": "Bind Email", |  | ||||||
|   "绑定": "Bind", |  | ||||||
|   "绑定邮箱地址": "Bind email address", |  | ||||||
|   "输入邮箱地址": "Enter email address", |  | ||||||
|   "重新发送": "Resend", |  | ||||||
|   "获取验证码": "Get verification code", |  | ||||||
|   "确认绑定": "Confirm binding", |  | ||||||
|   "取消": "Cancel", |  | ||||||
|   "危险操作": "Dangerous operation", |  | ||||||
|   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your own account, all data will be cleared and cannot be recovered", |  | ||||||
|   "输入你的账户名": "Enter your account name", |  | ||||||
|   "以确认删除": "To confirm deletion", |  | ||||||
|   "确认删除": "Confirm deletion", |  | ||||||
|   "未使用": "Not used", |  | ||||||
|   "已禁用": "Disabled", |  | ||||||
|   "已使用": "Used", |  | ||||||
|   "未知状态": "Unknown status", |  | ||||||
|   "操作成功完成!": "Operation successfully completed!", |  | ||||||
|   "搜索兑换码的 ID 和名称 ...": "Search for the ID and name of the redemption code ...", |  | ||||||
|   "名称": "Name", |  | ||||||
|   "状态": "Status", |  | ||||||
|   "额度": "Quota", |  | ||||||
|   "创建时间": "Creation time", |  | ||||||
|   "兑换时间": "Redemption time", |  | ||||||
|   "操作": "Operation", |  | ||||||
|   "尚未兑换": "Not yet redeemed", |  | ||||||
|   "已复制到剪贴板!": "Copied to clipboard!", |  | ||||||
|   "无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually. The redemption code has been filled in the search box.", |  | ||||||
|   "复制": "Copy", |  | ||||||
|   "删除": "Delete", |  | ||||||
|   "禁用": "Disable", |  | ||||||
|   "启用": "Enable", |  | ||||||
|   "编辑": "Edit", |  | ||||||
|   "添加新的兑换码": "Add new redemption code", |  | ||||||
|   "密码长度不得小于 8 位!": "Password length must not be less than 8 characters!", |  | ||||||
|   "两次输入的密码不一致": "The two passwords entered do not match", |  | ||||||
|   "注册成功!": "Registration successful!", |  | ||||||
|   "请填写注册邮箱!": "Please fill in the registration email!", |  | ||||||
|   "请在${verificationTimeout}秒后再试": "Please try again after ${verificationTimeout} seconds", |  | ||||||
|   "验证码发送成功,请检查你的邮箱!": "Verification code sent successfully, please check your email!", |  | ||||||
|   "已有账户?": "Already have an account?", |  | ||||||
|   "请输入用户名(最长 12 位)": "Please enter a username (up to 12 characters)", |  | ||||||
|   "请输入密码(最短 8 位,最长 20 位)": "Please enter a password (minimum 8 characters, maximum 20 characters)", |  | ||||||
|   "请再次输入密码": "Please enter the password again", |  | ||||||
|   "请输入邮箱地址": "Please enter an email address", |  | ||||||
|   "秒后可重发": "Can be resent after seconds", |  | ||||||
|   "请输入邮箱验证码": "Please enter the email verification code", |  | ||||||
|   "已过期": "Expired", |  | ||||||
|   "已启用": "Enabled", |  | ||||||
|   "已耗尽": "Exhausted", |  | ||||||
|   "无": "None", |  | ||||||
|   "令牌密钥": "API Key", |  | ||||||
|   "令牌状态": "Key status", |  | ||||||
|   "已用额度": "Used quota", |  | ||||||
|   "剩余额度": "Remaining quota", |  | ||||||
|   "过期时间": "Expiration time", |  | ||||||
|   "你确定要删除这个令牌吗?": "Are you sure you want to delete this key?", |  | ||||||
|   "无法复制到剪贴板,请手动复制,已将令牌密钥填入搜索框": "Unable to copy to clipboard, please copy manually. The key key has been filled in the search box.", |  | ||||||
|   "无限制": "Unlimited", |  | ||||||
|   "永不过期": "Never expires", |  | ||||||
|   "使用 API 访问令牌进行服务鉴权和计费。": "Use API Key for service authentication and billing.", |  | ||||||
|   "API 访问令牌关系到您的个人利益,请妥善留存,不要与其他人共享,也不要保存在客户端代码中。": "API Key is related to your personal interests. Please keep it properly. Do not share it with others or save it in client code.", |  | ||||||
|   "创建令牌": "Create Key", |  | ||||||
|   "什么都还没有,快去创建一个令牌开始使用吧!": "Nothing yet, go create a key to start using!", |  | ||||||
|   "你确定要删除该令牌吗": "Are you sure you want to delete this key", |  | ||||||
|   "导出令牌信息": "Export key information", |  | ||||||
|   "错误:未登录或登录已过期,请重新登录!": "Error: Not logged in or login has expired, please log in again!", |  | ||||||
|   "错误:请求次数过多,请稍后再试!": "Error: Too many requests, please try again later!", |  | ||||||
|   "错误:服务器内部错误,请联系管理员!": "Error: Server internal error, please contact the online customer service!", |  | ||||||
|   "本站仅作演示之用,无服务端!": "This site is for demonstration purposes only, no server!", |  | ||||||
|   "错误:": "Error:", |  | ||||||
|   "加载首页内容失败...": "Failed to load homepage content...", |  | ||||||
|   "系统状况": "System status", |  | ||||||
|   "系统信息": "System information", |  | ||||||
|   "系统信息总览": "System information overview", |  | ||||||
|   "名称:": "Name:", |  | ||||||
|   "版本:": "Version:", |  | ||||||
|   "源码:": "Source code:", |  | ||||||
|   "启动时间:": "Startup time:", |  | ||||||
|   "系统配置": "System configuration", |  | ||||||
|   "系统配置总览": "System configuration overview", |  | ||||||
|   "邮箱验证:": "Email verification:", |  | ||||||
|   "未启用": "Not enabled", |  | ||||||
|   "Turnstile 用户校验:": "Turnstile user verification:", |  | ||||||
|   "页面不存在": "Page does not exist", |  | ||||||
|   "请检查你的浏览器地址是否正确": "Please check if your browser address is correct", |  | ||||||
|   "个人设置": "Personal settings", |  | ||||||
|   "运营设置": "Operations settings", |  | ||||||
|   "系统设置": "System settings", |  | ||||||
|   "其他设置": "Other settings", |  | ||||||
|   "默认令牌": "Default key", |  | ||||||
|   "过期时间必须在当前时间之后!": "Expiration time must be after the current time!", |  | ||||||
|   "额度必须大于等于 0!": "Quota must be greater than or equal to 0!", |  | ||||||
|   "过期时间格式错误!": "Expiration time format error!", |  | ||||||
|   "创建令牌数量必须大于等于 1!": "The number of keys to create must be greater than or equal to 1!", |  | ||||||
|   "令牌修改成功": "API Key modification successful", |  | ||||||
|   "令牌创建成功": "API Key creation successful", |  | ||||||
|   "更新令牌信息": "Update key information", |  | ||||||
|   "创建新的令牌": "Create a new key", |  | ||||||
|   "请输入名称": "Please enter a name", |  | ||||||
|   "请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means unlimited", |  | ||||||
|   "无限额度": "Unlimited quota", |  | ||||||
|   "注意:启用无限额度后,已用额度将不再进行计算。": "Note: After enabling unlimited quota, the used quota will no longer be calculated.", |  | ||||||
|   "等于": "Equals", |  | ||||||
|   "请输入额度(单位:token)": "Please enter the quota (unit: token)", |  | ||||||
|   "创建令牌数量": "Create key quantity", |  | ||||||
|   "请输入令牌数量": "Please enter the number of keys", |  | ||||||
|   "注意:令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note: The quota of the key is only used to limit the maximum quota usage of the key itself, and the actual usage is subject to the remaining quota of the account.", |  | ||||||
|   "我的令牌": "My keys", |  | ||||||
|   "请输入额度兑换码!": "Please enter the redeem code!", |  | ||||||
|   "充值成功!": "Recharge successful!", |  | ||||||
|   "请求失败": "Request failed", |  | ||||||
|   "超级管理员未设置充值链接!": "The super administrator did not set a recharge link!", |  | ||||||
|   "充值额度": "Recharge quota", |  | ||||||
|   "兑换中...": "Redeeming...", |  | ||||||
|   "请点击充值以获取额度兑换码。": "Please click recharge to get the quota redemption code.", |  | ||||||
|   "用户信息更新成功!": "User information updated successfully!", |  | ||||||
|   "更新用户信息": "Update user information", |  | ||||||
|   "请输入新的用户名": "Please enter a new username", |  | ||||||
|   "请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characters", |  | ||||||
|   "请输入新的显示名称": "Please enter a new display name", |  | ||||||
|   "分组": "Group", |  | ||||||
|   "请选择分组": "Please select a group", |  | ||||||
|   "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group rate on the system settings page to add a new group:", |  | ||||||
|   "请输入新的剩余额度": "Please enter a new remaining quota", |  | ||||||
|   "已绑定的 GitHub 账户": "Bound GitHub account", |  | ||||||
|   "此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only, users need to bind through the relevant binding button on the personal settings page, cannot be directly modified", |  | ||||||
|   "已绑定的微信账户": "Bound WeChat account", |  | ||||||
|   "已绑定的邮箱账户": "Bound email account", |  | ||||||
|   "新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using the shortcut key Shift + F5", |  | ||||||
|   "无法正常连接至服务器!": "Unable to connect to the server normally!", |  | ||||||
|   "提示:": "Input:", |  | ||||||
|   "补全:": "Output:", |  | ||||||
|   "搜索令牌名称": "Search key name", |  | ||||||
|   "测试所有渠道": "Test all channels", |  | ||||||
|   "更新已启用渠道余额": "Update the balance of enabled channels" |  | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										68
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										68
									
								
								main.go
									
									
									
									
									
								
							| @@ -3,87 +3,87 @@ package main | |||||||
| import ( | import ( | ||||||
| 	"embed" | 	"embed" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/controller" | ||||||
|  | 	"one-api/middleware" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/router" | ||||||
|  | 	"os" | ||||||
|  | 	"strconv" | ||||||
|  |  | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-contrib/sessions/cookie" | 	"github.com/gin-contrib/sessions/cookie" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"github.com/songquanpeng/one-api/controller" |  | ||||||
| 	"github.com/songquanpeng/one-api/middleware" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" |  | ||||||
| 	"github.com/songquanpeng/one-api/router" |  | ||||||
| 	"os" |  | ||||||
| 	"strconv" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| //go:embed web/build/* | //go:embed web/build | ||||||
| var buildFS embed.FS | var buildFS embed.FS | ||||||
|  |  | ||||||
|  | //go:embed web/build/index.html | ||||||
|  | var indexPage []byte | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| 	logger.SetupLogger() | 	common.SetupLogger() | ||||||
| 	logger.SysLog(fmt.Sprintf("One API %s started", common.Version)) | 	common.SysLog("One API " + common.Version + " started") | ||||||
| 	if os.Getenv("GIN_MODE") != "debug" { | 	if os.Getenv("GIN_MODE") != "debug" { | ||||||
| 		gin.SetMode(gin.ReleaseMode) | 		gin.SetMode(gin.ReleaseMode) | ||||||
| 	} | 	} | ||||||
| 	if config.DebugEnabled { | 	if common.DebugEnabled { | ||||||
| 		logger.SysLog("running in debug mode") | 		common.SysLog("running in debug mode") | ||||||
| 	} | 	} | ||||||
| 	// Initialize SQL Database | 	// Initialize SQL Database | ||||||
| 	err := model.InitDB() | 	err := model.InitDB() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | 		common.FatalLog("failed to initialize database: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		err := model.CloseDB() | 		err := model.CloseDB() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.FatalLog("failed to close database: " + err.Error()) | 			common.FatalLog("failed to close database: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// Initialize Redis | 	// Initialize Redis | ||||||
| 	err = common.InitRedisClient() | 	err = common.InitRedisClient() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog("failed to initialize Redis: " + err.Error()) | 		common.FatalLog("failed to initialize Redis: " + err.Error()) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Initialize options | 	// Initialize options | ||||||
| 	model.InitOptionMap() | 	model.InitOptionMap() | ||||||
| 	logger.SysLog(fmt.Sprintf("using theme %s", config.Theme)) |  | ||||||
| 	if common.RedisEnabled { | 	if common.RedisEnabled { | ||||||
| 		// for compatibility with old versions | 		// for compatibility with old versions | ||||||
| 		config.MemoryCacheEnabled = true | 		common.MemoryCacheEnabled = true | ||||||
| 	} | 	} | ||||||
| 	if config.MemoryCacheEnabled { | 	if common.MemoryCacheEnabled { | ||||||
| 		logger.SysLog("memory cache enabled") | 		common.SysLog("memory cache enabled") | ||||||
| 		logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) | 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) | ||||||
| 		model.InitChannelCache() | 		model.InitChannelCache() | ||||||
| 	} | 	} | ||||||
| 	if config.MemoryCacheEnabled { | 	if common.MemoryCacheEnabled { | ||||||
| 		go model.SyncOptions(config.SyncFrequency) | 		go model.SyncOptions(common.SyncFrequency) | ||||||
| 		go model.SyncChannelCache(config.SyncFrequency) | 		go model.SyncChannelCache(common.SyncFrequency) | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) | 			common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		go controller.AutomaticallyUpdateChannels(frequency) | 		go controller.AutomaticallyUpdateChannels(frequency) | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) | 			common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		go controller.AutomaticallyTestChannels(frequency) | 		go controller.AutomaticallyTestChannels(frequency) | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | ||||||
| 		config.BatchUpdateEnabled = true | 		common.BatchUpdateEnabled = true | ||||||
| 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | ||||||
| 		model.InitBatchUpdater() | 		model.InitBatchUpdater() | ||||||
| 	} | 	} | ||||||
| 	openai.InitTokenEncoders() | 	common.InitTokenEncoders() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
| 	server := gin.New() | 	server := gin.New() | ||||||
| @@ -93,16 +93,16 @@ func main() { | |||||||
| 	server.Use(middleware.RequestId()) | 	server.Use(middleware.RequestId()) | ||||||
| 	middleware.SetUpLogger(server) | 	middleware.SetUpLogger(server) | ||||||
| 	// Initialize session store | 	// Initialize session store | ||||||
| 	store := cookie.NewStore([]byte(config.SessionSecret)) | 	store := cookie.NewStore([]byte(common.SessionSecret)) | ||||||
| 	server.Use(sessions.Sessions("session", store)) | 	server.Use(sessions.Sessions("session", store)) | ||||||
|  |  | ||||||
| 	router.SetRouter(server, buildFS) | 	router.SetRouter(server, buildFS, indexPage) | ||||||
| 	var port = os.Getenv("PORT") | 	var port = os.Getenv("PORT") | ||||||
| 	if port == "" { | 	if port == "" { | ||||||
| 		port = strconv.Itoa(*common.Port) | 		port = strconv.Itoa(*common.Port) | ||||||
| 	} | 	} | ||||||
| 	err = server.Run(":" + port) | 	err = server.Run(":" + port) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.FatalLog("failed to start HTTP server: " + err.Error()) | 		common.FatalLog("failed to start HTTP server: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,9 +3,9 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -108,7 +108,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 		c.Set("token_name", token.Name) | 		c.Set("token_name", token.Name) | ||||||
| 		if len(parts) > 1 { | 		if len(parts) > 1 { | ||||||
| 			if model.IsAdmin(token.UserId) { | 			if model.IsAdmin(token.UserId) { | ||||||
| 				c.Set("specific_channel_id", parts[1]) | 				c.Set("channelId", parts[1]) | ||||||
| 			} else { | 			} else { | ||||||
| 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||||
| 				return | 				return | ||||||
|   | |||||||
| @@ -1,112 +1,16 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"one-api/model" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"github.com/songquanpeng/one-api/model" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type ModelRequest struct { |  | ||||||
| 	Model string `json:"model"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Distribute() func(c *gin.Context) { | func Distribute() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		userId := c.GetInt("id") | 		userId := c.GetInt("id") | ||||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||||
| 		c.Set("group", userGroup) | 		c.Set("group", userGroup) | ||||||
| 		var requestModel string |  | ||||||
| 		var channel *model.Channel |  | ||||||
| 		channelId, ok := c.Get("specific_channel_id") |  | ||||||
| 		if ok { |  | ||||||
| 			id, err := strconv.Atoi(channelId.(string)) |  | ||||||
| 			if err != nil { |  | ||||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 			channel, err = model.GetChannelById(id, true) |  | ||||||
| 			if err != nil { |  | ||||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 			if channel.Status != common.ChannelStatusEnabled { |  | ||||||
| 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			// Select a channel for the user |  | ||||||
| 			var modelRequest ModelRequest |  | ||||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) |  | ||||||
| 			if err != nil { |  | ||||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = "text-moderation-stable" |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = c.Param("model") |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = "dall-e-2" |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = "whisper-1" |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			requestModel = modelRequest.Model |  | ||||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) |  | ||||||
| 			if err != nil { |  | ||||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) |  | ||||||
| 				if channel != nil { |  | ||||||
| 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) |  | ||||||
| 					message = "数据库一致性已被破坏,请联系管理员" |  | ||||||
| 				} |  | ||||||
| 				abortWithMessage(c, http.StatusServiceUnavailable, message) |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		SetupContextForSelectedChannel(c, channel, requestModel) |  | ||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { |  | ||||||
| 	c.Set("channel", channel.Type) |  | ||||||
| 	c.Set("channel_id", channel.Id) |  | ||||||
| 	c.Set("channel_name", channel.Name) |  | ||||||
| 	c.Set("model_mapping", channel.GetModelMapping()) |  | ||||||
| 	c.Set("original_model", modelName) // for retry |  | ||||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |  | ||||||
| 	c.Set("base_url", channel.GetBaseURL()) |  | ||||||
| 	// this is for backward compatibility |  | ||||||
| 	switch channel.Type { |  | ||||||
| 	case common.ChannelTypeAzure: |  | ||||||
| 		c.Set(common.ConfigKeyAPIVersion, channel.Other) |  | ||||||
| 	case common.ChannelTypeXunfei: |  | ||||||
| 		c.Set(common.ConfigKeyAPIVersion, channel.Other) |  | ||||||
| 	case common.ChannelTypeGemini: |  | ||||||
| 		c.Set(common.ConfigKeyAPIVersion, channel.Other) |  | ||||||
| 	case common.ChannelTypeAIProxyLibrary: |  | ||||||
| 		c.Set(common.ConfigKeyLibraryID, channel.Other) |  | ||||||
| 	case common.ChannelTypeAli: |  | ||||||
| 		c.Set(common.ConfigKeyPlugin, channel.Other) |  | ||||||
| 	} |  | ||||||
| 	cfg, _ := channel.LoadConfig() |  | ||||||
| 	for k, v := range cfg { |  | ||||||
| 		c.Set(common.ConfigKeyPrefix+k, v) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -3,14 +3,14 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" | 	"one-api/common" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetUpLogger(server *gin.Engine) { | func SetUpLogger(server *gin.Engine) { | ||||||
| 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | ||||||
| 		var requestID string | 		var requestID string | ||||||
| 		if param.Keys != nil { | 		if param.Keys != nil { | ||||||
| 			requestID = param.Keys[logger.RequestIdKey].(string) | 			requestID = param.Keys[common.RequestIdKey].(string) | ||||||
| 		} | 		} | ||||||
| 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | ||||||
| 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | ||||||
|   | |||||||
| @@ -4,9 +4,8 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -27,7 +26,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st | |||||||
| 	} | 	} | ||||||
| 	if listLength < int64(maxRequestNum) { | 	if listLength < int64(maxRequestNum) { | ||||||
| 		rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | 		rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | ||||||
| 		rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | 		rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | ||||||
| 	} else { | 	} else { | ||||||
| 		oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() | 		oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() | ||||||
| 		oldTime, err := time.Parse(timeFormat, oldTimeStr) | 		oldTime, err := time.Parse(timeFormat, oldTimeStr) | ||||||
| @@ -48,14 +47,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st | |||||||
| 		// time.Since will return negative number! | 		// time.Since will return negative number! | ||||||
| 		// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows | 		// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows | ||||||
| 		if int64(nowTime.Sub(oldTime).Seconds()) < duration { | 		if int64(nowTime.Sub(oldTime).Seconds()) < duration { | ||||||
| 			rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | 			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | ||||||
| 			c.Status(http.StatusTooManyRequests) | 			c.Status(http.StatusTooManyRequests) | ||||||
| 			c.Abort() | 			c.Abort() | ||||||
| 			return | 			return | ||||||
| 		} else { | 		} else { | ||||||
| 			rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | 			rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | ||||||
| 			rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) | 			rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) | ||||||
| 			rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | 			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -76,7 +75,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi | |||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		// It's safe to call multi times. | 		// It's safe to call multi times. | ||||||
| 		inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) | 		inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) | ||||||
| 		return func(c *gin.Context) { | 		return func(c *gin.Context) { | ||||||
| 			memoryRateLimiter(c, maxRequestNum, duration, mark) | 			memoryRateLimiter(c, maxRequestNum, duration, mark) | ||||||
| 		} | 		} | ||||||
| @@ -84,21 +83,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi | |||||||
| } | } | ||||||
|  |  | ||||||
| func GlobalWebRateLimit() func(c *gin.Context) { | func GlobalWebRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") | 	return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") | ||||||
| } | } | ||||||
|  |  | ||||||
| func GlobalAPIRateLimit() func(c *gin.Context) { | func GlobalAPIRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") | 	return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") | ||||||
| } | } | ||||||
|  |  | ||||||
| func CriticalRateLimit() func(c *gin.Context) { | func CriticalRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") | 	return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") | ||||||
| } | } | ||||||
|  |  | ||||||
| func DownloadRateLimit() func(c *gin.Context) { | func DownloadRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") | 	return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") | ||||||
| } | } | ||||||
|  |  | ||||||
| func UploadRateLimit() func(c *gin.Context) { | func UploadRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") | 	return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,8 +3,8 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
| 	"runtime/debug" | 	"runtime/debug" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { | |||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		defer func() { | 		defer func() { | ||||||
| 			if err := recover(); err != nil { | 			if err := recover(); err != nil { | ||||||
| 				logger.SysError(fmt.Sprintf("panic detected: %v", err)) | 				common.SysError(fmt.Sprintf("panic detected: %v", err)) | ||||||
| 				logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | 				common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||||
| 				c.JSON(http.StatusInternalServerError, gin.H{ | 				c.JSON(http.StatusInternalServerError, gin.H{ | ||||||
| 					"error": gin.H{ | 					"error": gin.H{ | ||||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | ||||||
|   | |||||||
| @@ -3,17 +3,16 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func RequestId() func(c *gin.Context) { | func RequestId() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		id := helper.GetTimeString() + helper.GetRandomNumberString(8) | 		id := common.GetTimeString() + common.GetRandomString(8) | ||||||
| 		c.Set(logger.RequestIdKey, id) | 		c.Set(common.RequestIdKey, id) | ||||||
| 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | 		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) | ||||||
| 		c.Request = c.Request.WithContext(ctx) | 		c.Request = c.Request.WithContext(ctx) | ||||||
| 		c.Header(logger.RequestIdKey, id) | 		c.Header(common.RequestIdKey, id) | ||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,10 +4,9 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"one-api/common" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type turnstileCheckResponse struct { | type turnstileCheckResponse struct { | ||||||
| @@ -16,7 +15,7 @@ type turnstileCheckResponse struct { | |||||||
|  |  | ||||||
| func TurnstileCheck() gin.HandlerFunc { | func TurnstileCheck() gin.HandlerFunc { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		if config.TurnstileCheckEnabled { | 		if common.TurnstileCheckEnabled { | ||||||
| 			session := sessions.Default(c) | 			session := sessions.Default(c) | ||||||
| 			turnstileChecked := session.Get("turnstile") | 			turnstileChecked := session.Get("turnstile") | ||||||
| 			if turnstileChecked != nil { | 			if turnstileChecked != nil { | ||||||
| @@ -33,12 +32,12 @@ func TurnstileCheck() gin.HandlerFunc { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ | 			rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ | ||||||
| 				"secret":   {config.TurnstileSecretKey}, | 				"secret":   {common.TurnstileSecretKey}, | ||||||
| 				"response": {response}, | 				"response": {response}, | ||||||
| 				"remoteip": {c.ClientIP()}, | 				"remoteip": {c.ClientIP()}, | ||||||
| 			}) | 			}) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.SysError(err.Error()) | 				common.SysError(err.Error()) | ||||||
| 				c.JSON(http.StatusOK, gin.H{ | 				c.JSON(http.StatusOK, gin.H{ | ||||||
| 					"success": false, | 					"success": false, | ||||||
| 					"message": err.Error(), | 					"message": err.Error(), | ||||||
| @@ -50,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc { | |||||||
| 			var res turnstileCheckResponse | 			var res turnstileCheckResponse | ||||||
| 			err = json.NewDecoder(rawRes.Body).Decode(&res) | 			err = json.NewDecoder(rawRes.Body).Decode(&res) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.SysError(err.Error()) | 				common.SysError(err.Error()) | ||||||
| 				c.JSON(http.StatusOK, gin.H{ | 				c.JSON(http.StatusOK, gin.H{ | ||||||
| 					"success": false, | 					"success": false, | ||||||
| 					"message": err.Error(), | 					"message": err.Error(), | ||||||
|   | |||||||
| @@ -2,17 +2,16 @@ package middleware | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" | 	"one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func abortWithMessage(c *gin.Context, statusCode int, message string) { | func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||||
| 	c.JSON(statusCode, gin.H{ | 	c.JSON(statusCode, gin.H{ | ||||||
| 		"error": gin.H{ | 		"error": gin.H{ | ||||||
| 			"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), | 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), | ||||||
| 			"type":    "one_api_error", | 			"type":    "one_api_error", | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| 	c.Abort() | 	c.Abort() | ||||||
| 	logger.Error(c.Request.Context(), message) | 	common.LogError(c.Request.Context(), message) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"one-api/common" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -39,6 +39,22 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | |||||||
| 	return &channel, err | 	return &channel, err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetGroupModels(group string) ([]string, error) { | ||||||
|  | 	var models []string | ||||||
|  | 	groupCol := "`group`" | ||||||
|  | 	trueVal := "1" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		groupCol = `"group"` | ||||||
|  | 		trueVal = "true" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err := DB.Model(&Ability{}).Where(groupCol+" = ? and enabled = ? ", group, trueVal).Distinct("model").Pluck("model", &models).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return models, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (channel *Channel) AddAbilities() error { | func (channel *Channel) AddAbilities() error { | ||||||
| 	models_ := strings.Split(channel.Models, ",") | 	models_ := strings.Split(channel.Models, ",") | ||||||
| 	groups_ := strings.Split(channel.Group, ",") | 	groups_ := strings.Split(channel.Group, ",") | ||||||
|   | |||||||
| @@ -4,10 +4,8 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
|  | 	"one-api/common" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -16,10 +14,10 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	TokenCacheSeconds         = config.SyncFrequency | 	TokenCacheSeconds         = common.SyncFrequency | ||||||
| 	UserId2GroupCacheSeconds  = config.SyncFrequency | 	UserId2GroupCacheSeconds  = common.SyncFrequency | ||||||
| 	UserId2QuotaCacheSeconds  = config.SyncFrequency | 	UserId2QuotaCacheSeconds  = common.SyncFrequency | ||||||
| 	UserId2StatusCacheSeconds = config.SyncFrequency | 	UserId2StatusCacheSeconds = common.SyncFrequency | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func CacheGetTokenByKey(key string) (*Token, error) { | func CacheGetTokenByKey(key string) (*Token, error) { | ||||||
| @@ -44,7 +42,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { | |||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("Redis set token error: " + err.Error()) | 			common.SysError("Redis set token error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		return &token, nil | 		return &token, nil | ||||||
| 	} | 	} | ||||||
| @@ -64,7 +62,7 @@ func CacheGetUserGroup(id int) (group string, err error) { | |||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("Redis set user group error: " + err.Error()) | 			common.SysError("Redis set user group error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return group, err | 	return group, err | ||||||
| @@ -82,7 +80,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { | |||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("Redis set user quota error: " + err.Error()) | 			common.SysError("Redis set user quota error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		return quota, err | 		return quota, err | ||||||
| 	} | 	} | ||||||
| @@ -94,7 +92,7 @@ func CacheUpdateUserQuota(id int) error { | |||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	quota, err := CacheGetUserQuota(id) | 	quota, err := GetUserQuota(id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -129,7 +127,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { | |||||||
| 	} | 	} | ||||||
| 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("Redis set user enabled error: " + err.Error()) | 		common.SysError("Redis set user enabled error: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return userEnabled, err | 	return userEnabled, err | ||||||
| } | } | ||||||
| @@ -180,19 +178,19 @@ func InitChannelCache() { | |||||||
| 	channelSyncLock.Lock() | 	channelSyncLock.Lock() | ||||||
| 	group2model2channels = newGroup2model2channels | 	group2model2channels = newGroup2model2channels | ||||||
| 	channelSyncLock.Unlock() | 	channelSyncLock.Unlock() | ||||||
| 	logger.SysLog("channels synced from database") | 	common.SysLog("channels synced from database") | ||||||
| } | } | ||||||
|  |  | ||||||
| func SyncChannelCache(frequency int) { | func SyncChannelCache(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Second) | 		time.Sleep(time.Duration(frequency) * time.Second) | ||||||
| 		logger.SysLog("syncing channels from database") | 		common.SysLog("syncing channels from database") | ||||||
| 		InitChannelCache() | 		InitChannelCache() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { | func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||||
| 	if !config.MemoryCacheEnabled { | 	if !common.MemoryCacheEnabled { | ||||||
| 		return GetRandomSatisfiedChannel(group, model) | 		return GetRandomSatisfiedChannel(group, model) | ||||||
| 	} | 	} | ||||||
| 	channelSyncLock.RLock() | 	channelSyncLock.RLock() | ||||||
| @@ -213,10 +211,24 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	idx := rand.Intn(endIdx) | 	idx := rand.Intn(endIdx) | ||||||
| 	if ignoreFirstPriority { |  | ||||||
| 		if endIdx < len(channels) { // which means there are more than one priority |  | ||||||
| 			idx = common.RandRange(endIdx, len(channels)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return channels[idx], nil | 	return channels[idx], nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func CacheGetGroupModels(group string) ([]string, error) { | ||||||
|  | 	if !common.MemoryCacheEnabled { | ||||||
|  | 		return GetGroupModels(group) | ||||||
|  | 	} | ||||||
|  | 	channelSyncLock.RLock() | ||||||
|  | 	defer channelSyncLock.RUnlock() | ||||||
|  |  | ||||||
|  | 	groupModels := group2model2channels[group] | ||||||
|  | 	if groupModels == nil { | ||||||
|  | 		return nil, errors.New("group not found") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	models := make([]string, 0) | ||||||
|  | 	for model := range groupModels { | ||||||
|  | 		models = append(models, model) | ||||||
|  | 	} | ||||||
|  | 	return models, nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,12 +1,8 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"one-api/common" | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -21,7 +17,7 @@ type Channel struct { | |||||||
| 	TestTime           int64   `json:"test_time" gorm:"bigint"` | 	TestTime           int64   `json:"test_time" gorm:"bigint"` | ||||||
| 	ResponseTime       int     `json:"response_time"` // in milliseconds | 	ResponseTime       int     `json:"response_time"` // in milliseconds | ||||||
| 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | ||||||
| 	Other              string  `json:"other"`   // DEPRECATED: please save config to field Config | 	Other              string  `json:"other"` | ||||||
| 	Balance            float64 `json:"balance"` // in USD | 	Balance            float64 `json:"balance"` // in USD | ||||||
| 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | ||||||
| 	Models             string  `json:"models"` | 	Models             string  `json:"models"` | ||||||
| @@ -29,7 +25,7 @@ type Channel struct { | |||||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||||
| 	Config             string  `json:"config"` | 	Proxy              string  `json:"proxy" gorm:"type:varchar(255);default:''"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||||
| @@ -48,7 +44,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) { | |||||||
| 	if common.UsingPostgreSQL { | 	if common.UsingPostgreSQL { | ||||||
| 		keyCol = `"key"` | 		keyCol = `"key"` | ||||||
| 	} | 	} | ||||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||||
| 	return channels, err | 	return channels, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -92,17 +88,11 @@ func (channel *Channel) GetBaseURL() string { | |||||||
| 	return *channel.BaseURL | 	return *channel.BaseURL | ||||||
| } | } | ||||||
|  |  | ||||||
| func (channel *Channel) GetModelMapping() map[string]string { | func (channel *Channel) GetModelMapping() string { | ||||||
| 	if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" { | 	if channel.ModelMapping == nil { | ||||||
| 		return nil | 		return "" | ||||||
| 	} | 	} | ||||||
| 	modelMapping := make(map[string]string) | 	return *channel.ModelMapping | ||||||
| 	err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error())) |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	return modelMapping |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (channel *Channel) Insert() error { | func (channel *Channel) Insert() error { | ||||||
| @@ -128,21 +118,21 @@ func (channel *Channel) Update() error { | |||||||
|  |  | ||||||
| func (channel *Channel) UpdateResponseTime(responseTime int64) { | func (channel *Channel) UpdateResponseTime(responseTime int64) { | ||||||
| 	err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ | 	err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ | ||||||
| 		TestTime:     helper.GetTimestamp(), | 		TestTime:     common.GetTimestamp(), | ||||||
| 		ResponseTime: int(responseTime), | 		ResponseTime: int(responseTime), | ||||||
| 	}).Error | 	}).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update response time: " + err.Error()) | 		common.SysError("failed to update response time: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (channel *Channel) UpdateBalance(balance float64) { | func (channel *Channel) UpdateBalance(balance float64) { | ||||||
| 	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ | 	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ | ||||||
| 		BalanceUpdatedTime: helper.GetTimestamp(), | 		BalanceUpdatedTime: common.GetTimestamp(), | ||||||
| 		Balance:            balance, | 		Balance:            balance, | ||||||
| 	}).Error | 	}).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update balance: " + err.Error()) | 		common.SysError("failed to update balance: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -156,31 +146,19 @@ func (channel *Channel) Delete() error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (channel *Channel) LoadConfig() (map[string]string, error) { |  | ||||||
| 	if channel.Config == "" { |  | ||||||
| 		return nil, nil |  | ||||||
| 	} |  | ||||||
| 	cfg := make(map[string]string) |  | ||||||
| 	err := json.Unmarshal([]byte(channel.Config), &cfg) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	return cfg, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func UpdateChannelStatusById(id int, status int) { | func UpdateChannelStatusById(id int, status int) { | ||||||
| 	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) | 	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update ability status: " + err.Error()) | 		common.SysError("failed to update ability status: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error | 	err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update channel status: " + err.Error()) | 		common.SysError("failed to update channel status: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelUsedQuota(id int, quota int) { | func UpdateChannelUsedQuota(id int, quota int) { | ||||||
| 	if config.BatchUpdateEnabled { | 	if common.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -190,7 +168,7 @@ func UpdateChannelUsedQuota(id int, quota int) { | |||||||
| func updateChannelUsedQuota(id int, quota int) { | func updateChannelUsedQuota(id int, quota int) { | ||||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update channel used quota: " + err.Error()) | 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										67
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										67
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -3,18 +3,15 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
|  |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Log struct { | type Log struct { | ||||||
| 	Id               int    `json:"id"` | 	Id               int    `json:"id;index:idx_created_at_id,priority:1"` | ||||||
| 	UserId           int    `json:"user_id" gorm:"index"` | 	UserId           int    `json:"user_id" gorm:"index"` | ||||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"` | 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` | ||||||
| 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||||
| 	Content          string `json:"content"` | 	Content          string `json:"content"` | ||||||
| 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||||
| @@ -26,6 +23,15 @@ type Log struct { | |||||||
| 	ChannelId        int    `json:"channel" gorm:"index"` | 	ChannelId        int    `json:"channel" gorm:"index"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type LogStatistic struct { | ||||||
|  | 	Day              string `gorm:"column:day"` | ||||||
|  | 	ModelName        string `gorm:"column:model_name"` | ||||||
|  | 	RequestCount     int    `gorm:"column:request_count"` | ||||||
|  | 	Quota            int    `gorm:"column:quota"` | ||||||
|  | 	PromptTokens     int    `gorm:"column:prompt_tokens"` | ||||||
|  | 	CompletionTokens int    `gorm:"column:completion_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	LogTypeUnknown = iota | 	LogTypeUnknown = iota | ||||||
| 	LogTypeTopup | 	LogTypeTopup | ||||||
| @@ -35,31 +41,31 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func RecordLog(userId int, logType int, content string) { | func RecordLog(userId int, logType int, content string) { | ||||||
| 	if logType == LogTypeConsume && !config.LogConsumeEnabled { | 	if logType == LogTypeConsume && !common.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	log := &Log{ | 	log := &Log{ | ||||||
| 		UserId:    userId, | 		UserId:    userId, | ||||||
| 		Username:  GetUsernameById(userId), | 		Username:  GetUsernameById(userId), | ||||||
| 		CreatedAt: helper.GetTimestamp(), | 		CreatedAt: common.GetTimestamp(), | ||||||
| 		Type:      logType, | 		Type:      logType, | ||||||
| 		Content:   content, | 		Content:   content, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to record log: " + err.Error()) | 		common.SysError("failed to record log: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||||
| 	if !config.LogConsumeEnabled { | 	if !common.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	log := &Log{ | 	log := &Log{ | ||||||
| 		UserId:           userId, | 		UserId:           userId, | ||||||
| 		Username:         GetUsernameById(userId), | 		Username:         GetUsernameById(userId), | ||||||
| 		CreatedAt:        helper.GetTimestamp(), | 		CreatedAt:        common.GetTimestamp(), | ||||||
| 		Type:             LogTypeConsume, | 		Type:             LogTypeConsume, | ||||||
| 		Content:          content, | 		Content:          content, | ||||||
| 		PromptTokens:     promptTokens, | 		PromptTokens:     promptTokens, | ||||||
| @@ -71,7 +77,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | |||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | 		common.LogError(ctx, "failed to record log: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -128,17 +134,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int | |||||||
| } | } | ||||||
|  |  | ||||||
| func SearchAllLogs(keyword string) (logs []*Log, err error) { | func SearchAllLogs(keyword string) (logs []*Log, err error) { | ||||||
| 	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error | 	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||||
| 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error | 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | ||||||
| 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | 	tx := DB.Table("logs").Select(assembleSumSelectStr("quota")) | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -162,7 +168,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||||
| 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | 	tx := DB.Table("logs").Select(assembleSumSelectStr("prompt_tokens") + " + " + assembleSumSelectStr("completion_tokens")) | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -187,16 +193,7 @@ func DeleteOldLog(targetTimestamp int64) (int64, error) { | |||||||
| 	return result.RowsAffected, result.Error | 	return result.RowsAffected, result.Error | ||||||
| } | } | ||||||
|  |  | ||||||
| type LogStatistic struct { | func SearchLogsByDayAndModel(user_id, start, end int) (LogStatistics []*LogStatistic, err error) { | ||||||
| 	Day              string `gorm:"column:day"` |  | ||||||
| 	ModelName        string `gorm:"column:model_name"` |  | ||||||
| 	RequestCount     int    `gorm:"column:request_count"` |  | ||||||
| 	Quota            int    `gorm:"column:quota"` |  | ||||||
| 	PromptTokens     int    `gorm:"column:prompt_tokens"` |  | ||||||
| 	CompletionTokens int    `gorm:"column:completion_tokens"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) { |  | ||||||
| 	groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day" | 	groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day" | ||||||
|  |  | ||||||
| 	if common.UsingPostgreSQL { | 	if common.UsingPostgreSQL { | ||||||
| @@ -219,7 +216,21 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis | |||||||
| 		AND created_at BETWEEN ? AND ? | 		AND created_at BETWEEN ? AND ? | ||||||
| 		GROUP BY day, model_name | 		GROUP BY day, model_name | ||||||
| 		ORDER BY day, model_name | 		ORDER BY day, model_name | ||||||
| 	`, userId, start, end).Scan(&LogStatistics).Error | 	`, user_id, start, end).Scan(&LogStatistics).Error | ||||||
|  |  | ||||||
|  | 	fmt.Println(user_id, start, end) | ||||||
|  |  | ||||||
| 	return LogStatistics, err | 	return LogStatistics, err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func assembleSumSelectStr(selectStr string) string { | ||||||
|  | 	sumSelectStr := "%s(sum(%s),0)" | ||||||
|  | 	nullfunc := "ifnull" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		nullfunc = "coalesce" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sumSelectStr = fmt.Sprintf(sumSelectStr, nullfunc, selectStr) | ||||||
|  |  | ||||||
|  | 	return sumSelectStr | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,14 +2,11 @@ package model | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"gorm.io/driver/mysql" | 	"gorm.io/driver/mysql" | ||||||
| 	"gorm.io/driver/postgres" | 	"gorm.io/driver/postgres" | ||||||
| 	"gorm.io/driver/sqlite" | 	"gorm.io/driver/sqlite" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"one-api/common" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -19,9 +16,9 @@ var DB *gorm.DB | |||||||
|  |  | ||||||
| func createRootAccountIfNeed() error { | func createRootAccountIfNeed() error { | ||||||
| 	var user User | 	var user User | ||||||
| 	//if user.Status != util.UserStatusEnabled { | 	//if user.Status != common.UserStatusEnabled { | ||||||
| 	if err := DB.First(&user).Error; err != nil { | 	if err := DB.First(&user).Error; err != nil { | ||||||
| 		logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") | 		common.SysLog("no user exists, create a root user for you: username is root, password is 123456") | ||||||
| 		hashedPassword, err := common.Password2Hash("123456") | 		hashedPassword, err := common.Password2Hash("123456") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| @@ -32,7 +29,7 @@ func createRootAccountIfNeed() error { | |||||||
| 			Role:        common.RoleRootUser, | 			Role:        common.RoleRootUser, | ||||||
| 			Status:      common.UserStatusEnabled, | 			Status:      common.UserStatusEnabled, | ||||||
| 			DisplayName: "Root User", | 			DisplayName: "Root User", | ||||||
| 			AccessToken: helper.GetUUID(), | 			AccessToken: common.GetUUID(), | ||||||
| 			Quota:       100000000, | 			Quota:       100000000, | ||||||
| 		} | 		} | ||||||
| 		DB.Create(&rootUser) | 		DB.Create(&rootUser) | ||||||
| @@ -45,7 +42,7 @@ func chooseDB() (*gorm.DB, error) { | |||||||
| 		dsn := os.Getenv("SQL_DSN") | 		dsn := os.Getenv("SQL_DSN") | ||||||
| 		if strings.HasPrefix(dsn, "postgres://") { | 		if strings.HasPrefix(dsn, "postgres://") { | ||||||
| 			// Use PostgreSQL | 			// Use PostgreSQL | ||||||
| 			logger.SysLog("using PostgreSQL as database") | 			common.SysLog("using PostgreSQL as database") | ||||||
| 			common.UsingPostgreSQL = true | 			common.UsingPostgreSQL = true | ||||||
| 			return gorm.Open(postgres.New(postgres.Config{ | 			return gorm.Open(postgres.New(postgres.Config{ | ||||||
| 				DSN:                  dsn, | 				DSN:                  dsn, | ||||||
| @@ -55,13 +52,13 @@ func chooseDB() (*gorm.DB, error) { | |||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 		// Use MySQL | 		// Use MySQL | ||||||
| 		logger.SysLog("using MySQL as database") | 		common.SysLog("using MySQL as database") | ||||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||||
| 			PrepareStmt: true, // precompile SQL | 			PrepareStmt: true, // precompile SQL | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 	// Use SQLite | 	// Use SQLite | ||||||
| 	logger.SysLog("SQL_DSN not set, using SQLite as database") | 	common.SysLog("SQL_DSN not set, using SQLite as database") | ||||||
| 	common.UsingSQLite = true | 	common.UsingSQLite = true | ||||||
| 	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) | 	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) | ||||||
| 	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ | 	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ | ||||||
| @@ -72,7 +69,7 @@ func chooseDB() (*gorm.DB, error) { | |||||||
| func InitDB() (err error) { | func InitDB() (err error) { | ||||||
| 	db, err := chooseDB() | 	db, err := chooseDB() | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		if config.DebugEnabled { | 		if common.DebugEnabled { | ||||||
| 			db = db.Debug() | 			db = db.Debug() | ||||||
| 		} | 		} | ||||||
| 		DB = db | 		DB = db | ||||||
| @@ -80,14 +77,14 @@ func InitDB() (err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) | 		sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) | ||||||
| 		sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) | 		sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) | ||||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) | 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) | ||||||
|  |  | ||||||
| 		if !config.IsMasterNode { | 		if !common.IsMasterNode { | ||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
| 		logger.SysLog("database migration started") | 		common.SysLog("database migration started") | ||||||
| 		err = db.AutoMigrate(&Channel{}) | 		err = db.AutoMigrate(&Channel{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| @@ -116,11 +113,11 @@ func InitDB() (err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		logger.SysLog("database migrated") | 		common.SysLog("database migrated") | ||||||
| 		err = createRootAccountIfNeed() | 		err = createRootAccountIfNeed() | ||||||
| 		return err | 		return err | ||||||
| 	} else { | 	} else { | ||||||
| 		logger.FatalLog(err) | 		common.FatalLog(err) | ||||||
| 	} | 	} | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										217
									
								
								model/option.go
									
									
									
									
									
								
							
							
						
						
									
										217
									
								
								model/option.go
									
									
									
									
									
								
							| @@ -1,9 +1,7 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -22,57 +20,59 @@ func AllOption() ([]*Option, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func InitOptionMap() { | func InitOptionMap() { | ||||||
| 	config.OptionMapRWMutex.Lock() | 	common.OptionMapRWMutex.Lock() | ||||||
| 	config.OptionMap = make(map[string]string) | 	common.OptionMap = make(map[string]string) | ||||||
| 	config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) | 	common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) | ||||||
| 	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) | 	common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) | ||||||
| 	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) | 	common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) | ||||||
| 	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) | 	common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) | ||||||
| 	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) | 	common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) | ||||||
| 	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) | 	common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) | ||||||
| 	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) | 	common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) | ||||||
| 	config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) | 	common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) | ||||||
| 	config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) | 	common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) | ||||||
| 	config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) | 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) | ||||||
| 	config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) | 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) | ||||||
| 	config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) | 	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) | ||||||
| 	config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) | 	common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) | ||||||
| 	config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) | 	common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) | ||||||
| 	config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) | 	common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) | ||||||
| 	config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") | 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | ||||||
| 	config.OptionMap["SMTPServer"] = "" | 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) | ||||||
| 	config.OptionMap["SMTPFrom"] = "" | 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) | ||||||
| 	config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) | 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) | ||||||
| 	config.OptionMap["SMTPAccount"] = "" | 	common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") | ||||||
| 	config.OptionMap["SMTPToken"] = "" | 	common.OptionMap["SMTPServer"] = "" | ||||||
| 	config.OptionMap["Notice"] = "" | 	common.OptionMap["SMTPFrom"] = "" | ||||||
| 	config.OptionMap["About"] = "" | 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) | ||||||
| 	config.OptionMap["HomePageContent"] = "" | 	common.OptionMap["SMTPAccount"] = "" | ||||||
| 	config.OptionMap["Footer"] = config.Footer | 	common.OptionMap["SMTPToken"] = "" | ||||||
| 	config.OptionMap["SystemName"] = config.SystemName | 	common.OptionMap["Notice"] = "" | ||||||
| 	config.OptionMap["Logo"] = config.Logo | 	common.OptionMap["About"] = "" | ||||||
| 	config.OptionMap["ServerAddress"] = "" | 	common.OptionMap["HomePageContent"] = "" | ||||||
| 	config.OptionMap["GitHubClientId"] = "" | 	common.OptionMap["Footer"] = common.Footer | ||||||
| 	config.OptionMap["GitHubClientSecret"] = "" | 	common.OptionMap["SystemName"] = common.SystemName | ||||||
| 	config.OptionMap["WeChatServerAddress"] = "" | 	common.OptionMap["Logo"] = common.Logo | ||||||
| 	config.OptionMap["WeChatServerToken"] = "" | 	common.OptionMap["ServerAddress"] = "" | ||||||
| 	config.OptionMap["WeChatAccountQRCodeImageURL"] = "" | 	common.OptionMap["GitHubClientId"] = "" | ||||||
| 	config.OptionMap["TurnstileSiteKey"] = "" | 	common.OptionMap["GitHubClientSecret"] = "" | ||||||
| 	config.OptionMap["TurnstileSecretKey"] = "" | 	common.OptionMap["WeChatServerAddress"] = "" | ||||||
| 	config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) | 	common.OptionMap["WeChatServerToken"] = "" | ||||||
| 	config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) | 	common.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||||
| 	config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) | 	common.OptionMap["TurnstileSiteKey"] = "" | ||||||
| 	config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) | 	common.OptionMap["TurnstileSecretKey"] = "" | ||||||
| 	config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) | 	common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) | ||||||
| 	config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | 	common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) | ||||||
| 	config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | 	common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) | ||||||
| 	config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() | 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) | ||||||
| 	config.OptionMap["TopUpLink"] = config.TopUpLink | 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) | ||||||
| 	config.OptionMap["ChatLink"] = config.ChatLink | 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | ||||||
| 	config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) | 	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | ||||||
| 	config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) | 	common.OptionMap["TopUpLink"] = common.TopUpLink | ||||||
| 	config.OptionMap["Theme"] = config.Theme | 	common.OptionMap["ChatLink"] = common.ChatLink | ||||||
| 	config.OptionMapRWMutex.Unlock() | 	common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) | ||||||
|  | 	common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) | ||||||
|  | 	common.OptionMapRWMutex.Unlock() | ||||||
| 	loadOptionsFromDatabase() | 	loadOptionsFromDatabase() | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -81,7 +81,7 @@ func loadOptionsFromDatabase() { | |||||||
| 	for _, option := range options { | 	for _, option := range options { | ||||||
| 		err := updateOptionMap(option.Key, option.Value) | 		err := updateOptionMap(option.Key, option.Value) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.SysError("failed to update option map: " + err.Error()) | 			common.SysError("failed to update option map: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -89,7 +89,7 @@ func loadOptionsFromDatabase() { | |||||||
| func SyncOptions(frequency int) { | func SyncOptions(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Second) | 		time.Sleep(time.Duration(frequency) * time.Second) | ||||||
| 		logger.SysLog("syncing options from database") | 		common.SysLog("syncing options from database") | ||||||
| 		loadOptionsFromDatabase() | 		loadOptionsFromDatabase() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -111,106 +111,115 @@ func UpdateOption(key string, value string) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func updateOptionMap(key string, value string) (err error) { | func updateOptionMap(key string, value string) (err error) { | ||||||
| 	config.OptionMapRWMutex.Lock() | 	common.OptionMapRWMutex.Lock() | ||||||
| 	defer config.OptionMapRWMutex.Unlock() | 	defer common.OptionMapRWMutex.Unlock() | ||||||
| 	config.OptionMap[key] = value | 	common.OptionMap[key] = value | ||||||
|  | 	if strings.HasSuffix(key, "Permission") { | ||||||
|  | 		intValue, _ := strconv.Atoi(value) | ||||||
|  | 		switch key { | ||||||
|  | 		case "FileUploadPermission": | ||||||
|  | 			common.FileUploadPermission = intValue | ||||||
|  | 		case "FileDownloadPermission": | ||||||
|  | 			common.FileDownloadPermission = intValue | ||||||
|  | 		case "ImageUploadPermission": | ||||||
|  | 			common.ImageUploadPermission = intValue | ||||||
|  | 		case "ImageDownloadPermission": | ||||||
|  | 			common.ImageDownloadPermission = intValue | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	if strings.HasSuffix(key, "Enabled") { | 	if strings.HasSuffix(key, "Enabled") { | ||||||
| 		boolValue := value == "true" | 		boolValue := value == "true" | ||||||
| 		switch key { | 		switch key { | ||||||
| 		case "PasswordRegisterEnabled": | 		case "PasswordRegisterEnabled": | ||||||
| 			config.PasswordRegisterEnabled = boolValue | 			common.PasswordRegisterEnabled = boolValue | ||||||
| 		case "PasswordLoginEnabled": | 		case "PasswordLoginEnabled": | ||||||
| 			config.PasswordLoginEnabled = boolValue | 			common.PasswordLoginEnabled = boolValue | ||||||
| 		case "EmailVerificationEnabled": | 		case "EmailVerificationEnabled": | ||||||
| 			config.EmailVerificationEnabled = boolValue | 			common.EmailVerificationEnabled = boolValue | ||||||
| 		case "GitHubOAuthEnabled": | 		case "GitHubOAuthEnabled": | ||||||
| 			config.GitHubOAuthEnabled = boolValue | 			common.GitHubOAuthEnabled = boolValue | ||||||
| 		case "WeChatAuthEnabled": | 		case "WeChatAuthEnabled": | ||||||
| 			config.WeChatAuthEnabled = boolValue | 			common.WeChatAuthEnabled = boolValue | ||||||
| 		case "TurnstileCheckEnabled": | 		case "TurnstileCheckEnabled": | ||||||
| 			config.TurnstileCheckEnabled = boolValue | 			common.TurnstileCheckEnabled = boolValue | ||||||
| 		case "RegisterEnabled": | 		case "RegisterEnabled": | ||||||
| 			config.RegisterEnabled = boolValue | 			common.RegisterEnabled = boolValue | ||||||
| 		case "EmailDomainRestrictionEnabled": | 		case "EmailDomainRestrictionEnabled": | ||||||
| 			config.EmailDomainRestrictionEnabled = boolValue | 			common.EmailDomainRestrictionEnabled = boolValue | ||||||
| 		case "AutomaticDisableChannelEnabled": | 		case "AutomaticDisableChannelEnabled": | ||||||
| 			config.AutomaticDisableChannelEnabled = boolValue | 			common.AutomaticDisableChannelEnabled = boolValue | ||||||
| 		case "AutomaticEnableChannelEnabled": | 		case "AutomaticEnableChannelEnabled": | ||||||
| 			config.AutomaticEnableChannelEnabled = boolValue | 			common.AutomaticEnableChannelEnabled = boolValue | ||||||
| 		case "ApproximateTokenEnabled": | 		case "ApproximateTokenEnabled": | ||||||
| 			config.ApproximateTokenEnabled = boolValue | 			common.ApproximateTokenEnabled = boolValue | ||||||
| 		case "LogConsumeEnabled": | 		case "LogConsumeEnabled": | ||||||
| 			config.LogConsumeEnabled = boolValue | 			common.LogConsumeEnabled = boolValue | ||||||
| 		case "DisplayInCurrencyEnabled": | 		case "DisplayInCurrencyEnabled": | ||||||
| 			config.DisplayInCurrencyEnabled = boolValue | 			common.DisplayInCurrencyEnabled = boolValue | ||||||
| 		case "DisplayTokenStatEnabled": | 		case "DisplayTokenStatEnabled": | ||||||
| 			config.DisplayTokenStatEnabled = boolValue | 			common.DisplayTokenStatEnabled = boolValue | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	switch key { | 	switch key { | ||||||
| 	case "EmailDomainWhitelist": | 	case "EmailDomainWhitelist": | ||||||
| 		config.EmailDomainWhitelist = strings.Split(value, ",") | 		common.EmailDomainWhitelist = strings.Split(value, ",") | ||||||
| 	case "SMTPServer": | 	case "SMTPServer": | ||||||
| 		config.SMTPServer = value | 		common.SMTPServer = value | ||||||
| 	case "SMTPPort": | 	case "SMTPPort": | ||||||
| 		intValue, _ := strconv.Atoi(value) | 		intValue, _ := strconv.Atoi(value) | ||||||
| 		config.SMTPPort = intValue | 		common.SMTPPort = intValue | ||||||
| 	case "SMTPAccount": | 	case "SMTPAccount": | ||||||
| 		config.SMTPAccount = value | 		common.SMTPAccount = value | ||||||
| 	case "SMTPFrom": | 	case "SMTPFrom": | ||||||
| 		config.SMTPFrom = value | 		common.SMTPFrom = value | ||||||
| 	case "SMTPToken": | 	case "SMTPToken": | ||||||
| 		config.SMTPToken = value | 		common.SMTPToken = value | ||||||
| 	case "ServerAddress": | 	case "ServerAddress": | ||||||
| 		config.ServerAddress = value | 		common.ServerAddress = value | ||||||
| 	case "GitHubClientId": | 	case "GitHubClientId": | ||||||
| 		config.GitHubClientId = value | 		common.GitHubClientId = value | ||||||
| 	case "GitHubClientSecret": | 	case "GitHubClientSecret": | ||||||
| 		config.GitHubClientSecret = value | 		common.GitHubClientSecret = value | ||||||
| 	case "Footer": | 	case "Footer": | ||||||
| 		config.Footer = value | 		common.Footer = value | ||||||
| 	case "SystemName": | 	case "SystemName": | ||||||
| 		config.SystemName = value | 		common.SystemName = value | ||||||
| 	case "Logo": | 	case "Logo": | ||||||
| 		config.Logo = value | 		common.Logo = value | ||||||
| 	case "WeChatServerAddress": | 	case "WeChatServerAddress": | ||||||
| 		config.WeChatServerAddress = value | 		common.WeChatServerAddress = value | ||||||
| 	case "WeChatServerToken": | 	case "WeChatServerToken": | ||||||
| 		config.WeChatServerToken = value | 		common.WeChatServerToken = value | ||||||
| 	case "WeChatAccountQRCodeImageURL": | 	case "WeChatAccountQRCodeImageURL": | ||||||
| 		config.WeChatAccountQRCodeImageURL = value | 		common.WeChatAccountQRCodeImageURL = value | ||||||
| 	case "TurnstileSiteKey": | 	case "TurnstileSiteKey": | ||||||
| 		config.TurnstileSiteKey = value | 		common.TurnstileSiteKey = value | ||||||
| 	case "TurnstileSecretKey": | 	case "TurnstileSecretKey": | ||||||
| 		config.TurnstileSecretKey = value | 		common.TurnstileSecretKey = value | ||||||
| 	case "QuotaForNewUser": | 	case "QuotaForNewUser": | ||||||
| 		config.QuotaForNewUser, _ = strconv.Atoi(value) | 		common.QuotaForNewUser, _ = strconv.Atoi(value) | ||||||
| 	case "QuotaForInviter": | 	case "QuotaForInviter": | ||||||
| 		config.QuotaForInviter, _ = strconv.Atoi(value) | 		common.QuotaForInviter, _ = strconv.Atoi(value) | ||||||
| 	case "QuotaForInvitee": | 	case "QuotaForInvitee": | ||||||
| 		config.QuotaForInvitee, _ = strconv.Atoi(value) | 		common.QuotaForInvitee, _ = strconv.Atoi(value) | ||||||
| 	case "QuotaRemindThreshold": | 	case "QuotaRemindThreshold": | ||||||
| 		config.QuotaRemindThreshold, _ = strconv.Atoi(value) | 		common.QuotaRemindThreshold, _ = strconv.Atoi(value) | ||||||
| 	case "PreConsumedQuota": | 	case "PreConsumedQuota": | ||||||
| 		config.PreConsumedQuota, _ = strconv.Atoi(value) | 		common.PreConsumedQuota, _ = strconv.Atoi(value) | ||||||
| 	case "RetryTimes": | 	case "RetryTimes": | ||||||
| 		config.RetryTimes, _ = strconv.Atoi(value) | 		common.RetryTimes, _ = strconv.Atoi(value) | ||||||
| 	case "ModelRatio": | 	case "ModelRatio": | ||||||
| 		err = common.UpdateModelRatioByJSONString(value) | 		err = common.UpdateModelRatioByJSONString(value) | ||||||
| 	case "GroupRatio": | 	case "GroupRatio": | ||||||
| 		err = common.UpdateGroupRatioByJSONString(value) | 		err = common.UpdateGroupRatioByJSONString(value) | ||||||
| 	case "CompletionRatio": |  | ||||||
| 		err = common.UpdateCompletionRatioByJSONString(value) |  | ||||||
| 	case "TopUpLink": | 	case "TopUpLink": | ||||||
| 		config.TopUpLink = value | 		common.TopUpLink = value | ||||||
| 	case "ChatLink": | 	case "ChatLink": | ||||||
| 		config.ChatLink = value | 		common.ChatLink = value | ||||||
| 	case "ChannelDisableThreshold": | 	case "ChannelDisableThreshold": | ||||||
| 		config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) | 		common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) | ||||||
| 	case "QuotaPerUnit": | 	case "QuotaPerUnit": | ||||||
| 		config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) | 		common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) | ||||||
| 	case "Theme": |  | ||||||
| 		config.Theme = value |  | ||||||
| 	} | 	} | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,8 +3,8 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -28,7 +28,7 @@ func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { | func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { | ||||||
| 	err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error | 	err = DB.Where("id = ? or name LIKE ?", common.String2Int(keyword), keyword+"%").Find(&redemptions).Error | ||||||
| 	return redemptions, err | 	return redemptions, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -68,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		redemption.RedeemedTime = helper.GetTimestamp() | 		redemption.RedeemedTime = common.GetTimestamp() | ||||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | 		redemption.Status = common.RedemptionCodeStatusUsed | ||||||
| 		err = tx.Save(redemption).Error | 		err = tx.Save(redemption).Error | ||||||
| 		return err | 		return err | ||||||
|   | |||||||
| @@ -3,11 +3,8 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"one-api/common" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Token struct { | type Token struct { | ||||||
| @@ -41,43 +38,39 @@ func ValidateUserToken(key string) (token *Token, err error) { | |||||||
| 		return nil, errors.New("未提供令牌") | 		return nil, errors.New("未提供令牌") | ||||||
| 	} | 	} | ||||||
| 	token, err = CacheGetTokenByKey(key) | 	token, err = CacheGetTokenByKey(key) | ||||||
| 	if err != nil { | 	if err == nil { | ||||||
| 		logger.SysError("CacheGetTokenByKey failed: " + err.Error()) | 		if token.Status == common.TokenStatusExhausted { | ||||||
| 		if errors.Is(err, gorm.ErrRecordNotFound) { | 			return nil, errors.New("该令牌额度已用尽") | ||||||
| 			return nil, errors.New("无效的令牌") | 		} else if token.Status == common.TokenStatusExpired { | ||||||
|  | 			return nil, errors.New("该令牌已过期") | ||||||
| 		} | 		} | ||||||
| 		return nil, errors.New("令牌验证失败") | 		if token.Status != common.TokenStatusEnabled { | ||||||
| 	} | 			return nil, errors.New("该令牌状态不可用") | ||||||
| 	if token.Status == common.TokenStatusExhausted { | 		} | ||||||
| 		return nil, errors.New("该令牌额度已用尽") | 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | ||||||
| 	} else if token.Status == common.TokenStatusExpired { | 			if !common.RedisEnabled { | ||||||
| 		return nil, errors.New("该令牌已过期") | 				token.Status = common.TokenStatusExpired | ||||||
| 	} | 				err := token.SelectUpdate() | ||||||
| 	if token.Status != common.TokenStatusEnabled { | 				if err != nil { | ||||||
| 		return nil, errors.New("该令牌状态不可用") | 					common.SysError("failed to update token status" + err.Error()) | ||||||
| 	} | 				} | ||||||
| 	if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { |  | ||||||
| 		if !common.RedisEnabled { |  | ||||||
| 			token.Status = common.TokenStatusExpired |  | ||||||
| 			err := token.SelectUpdate() |  | ||||||
| 			if err != nil { |  | ||||||
| 				logger.SysError("failed to update token status" + err.Error()) |  | ||||||
| 			} | 			} | ||||||
|  | 			return nil, errors.New("该令牌已过期") | ||||||
| 		} | 		} | ||||||
| 		return nil, errors.New("该令牌已过期") | 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||||
| 	} | 			if !common.RedisEnabled { | ||||||
| 	if !token.UnlimitedQuota && token.RemainQuota <= 0 { | 				// in this case, we can make sure the token is exhausted | ||||||
| 		if !common.RedisEnabled { | 				token.Status = common.TokenStatusExhausted | ||||||
| 			// in this case, we can make sure the token is exhausted | 				err := token.SelectUpdate() | ||||||
| 			token.Status = common.TokenStatusExhausted | 				if err != nil { | ||||||
| 			err := token.SelectUpdate() | 					common.SysError("failed to update token status" + err.Error()) | ||||||
| 			if err != nil { | 				} | ||||||
| 				logger.SysError("failed to update token status" + err.Error()) |  | ||||||
| 			} | 			} | ||||||
|  | 			return nil, errors.New("该令牌额度已用尽") | ||||||
| 		} | 		} | ||||||
| 		return nil, errors.New("该令牌额度已用尽") | 		return token, nil | ||||||
| 	} | 	} | ||||||
| 	return token, nil | 	return nil, errors.New("无效的令牌") | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetTokenByIds(id int, userId int) (*Token, error) { | func GetTokenByIds(id int, userId int) (*Token, error) { | ||||||
| @@ -141,7 +134,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if config.BatchUpdateEnabled { | 	if common.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -153,7 +146,7 @@ func increaseTokenQuota(id int, quota int) (err error) { | |||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||||
| 			"used_quota":    gorm.Expr("used_quota - ?", quota), | 			"used_quota":    gorm.Expr("used_quota - ?", quota), | ||||||
| 			"accessed_time": helper.GetTimestamp(), | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
| @@ -163,7 +156,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if config.BatchUpdateEnabled { | 	if common.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -175,7 +168,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
| 			"accessed_time": helper.GetTimestamp(), | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
| @@ -199,24 +192,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | |||||||
| 	if userQuota < quota { | 	if userQuota < quota { | ||||||
| 		return errors.New("用户额度不足") | 		return errors.New("用户额度不足") | ||||||
| 	} | 	} | ||||||
| 	quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold | 	quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold | ||||||
| 	noMoreQuota := userQuota-quota <= 0 | 	noMoreQuota := userQuota-quota <= 0 | ||||||
| 	if quotaTooLow || noMoreQuota { | 	if quotaTooLow || noMoreQuota { | ||||||
| 		go func() { | 		go func() { | ||||||
| 			email, err := GetUserEmail(token.UserId) | 			email, err := GetUserEmail(token.UserId) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.SysError("failed to fetch user email: " + err.Error()) | 				common.SysError("failed to fetch user email: " + err.Error()) | ||||||
| 			} | 			} | ||||||
| 			prompt := "您的额度即将用尽" | 			prompt := "您的额度即将用尽" | ||||||
| 			if noMoreQuota { | 			if noMoreQuota { | ||||||
| 				prompt = "您的额度已用尽" | 				prompt = "您的额度已用尽" | ||||||
| 			} | 			} | ||||||
| 			if email != "" { | 			if email != "" { | ||||||
| 				topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) | 				topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) | ||||||
| 				err = common.SendEmail(prompt, email, | 				err = common.SendEmail(prompt, email, | ||||||
| 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					logger.SysError("failed to send email" + err.Error()) | 					common.SysError("failed to send email" + err.Error()) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
|   | |||||||
| @@ -3,12 +3,10 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/songquanpeng/one-api/common" | 	"one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/config" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/helper" |  | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // User if you add sensitive fields, don't forget to clean them in setupLogin function. | // User if you add sensitive fields, don't forget to clean them in setupLogin function. | ||||||
| @@ -18,7 +16,7 @@ type User struct { | |||||||
| 	Username         string `json:"username" gorm:"unique;index" validate:"max=12"` | 	Username         string `json:"username" gorm:"unique;index" validate:"max=12"` | ||||||
| 	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"` | 	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"` | ||||||
| 	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"` | 	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"` | ||||||
| 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, util | 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, common | ||||||
| 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | ||||||
| 	Email            string `json:"email" gorm:"index" validate:"max=50"` | 	Email            string `json:"email" gorm:"index" validate:"max=50"` | ||||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||||
| @@ -45,11 +43,8 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUsers(keyword string) (users []*User, err error) { | func SearchUsers(keyword string) (users []*User, err error) { | ||||||
| 	if !common.UsingPostgreSQL { | 	err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", common.String2Int(keyword), keyword+"%", keyword+"%", keyword+"%").Find(&users).Error | ||||||
| 		err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error |  | ||||||
| 	} else { |  | ||||||
| 		err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error |  | ||||||
| 	} |  | ||||||
| 	return users, err | 	return users, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -92,24 +87,24 @@ func (user *User) Insert(inviterId int) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	user.Quota = config.QuotaForNewUser | 	user.Quota = common.QuotaForNewUser | ||||||
| 	user.AccessToken = helper.GetUUID() | 	user.AccessToken = common.GetUUID() | ||||||
| 	user.AffCode = helper.GetRandomString(4) | 	user.AffCode = common.GetRandomString(4) | ||||||
| 	result := DB.Create(user) | 	result := DB.Create(user) | ||||||
| 	if result.Error != nil { | 	if result.Error != nil { | ||||||
| 		return result.Error | 		return result.Error | ||||||
| 	} | 	} | ||||||
| 	if config.QuotaForNewUser > 0 { | 	if common.QuotaForNewUser > 0 { | ||||||
| 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) | 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) | ||||||
| 	} | 	} | ||||||
| 	if inviterId != 0 { | 	if inviterId != 0 { | ||||||
| 		if config.QuotaForInvitee > 0 { | 		if common.QuotaForInvitee > 0 { | ||||||
| 			_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) | 			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) | ||||||
| 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) | 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) | ||||||
| 		} | 		} | ||||||
| 		if config.QuotaForInviter > 0 { | 		if common.QuotaForInviter > 0 { | ||||||
| 			_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) | 			_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) | ||||||
| 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) | 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| @@ -144,15 +139,7 @@ func (user *User) ValidateAndFill() (err error) { | |||||||
| 	if user.Username == "" || password == "" { | 	if user.Username == "" || password == "" { | ||||||
| 		return errors.New("用户名或密码为空") | 		return errors.New("用户名或密码为空") | ||||||
| 	} | 	} | ||||||
| 	err = DB.Where("username = ?", user.Username).First(user).Error | 	DB.Where(User{Username: user.Username}).First(user) | ||||||
| 	if err != nil { |  | ||||||
| 		// we must make sure check username firstly |  | ||||||
| 		// consider this case: a malicious user set his username as other's email |  | ||||||
| 		err := DB.Where("email = ?", user.Username).First(user).Error |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errors.New("用户名或密码错误,或用户已被封禁") |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	okay := common.ValidatePasswordAndHash(password, user.Password) | 	okay := common.ValidatePasswordAndHash(password, user.Password) | ||||||
| 	if !okay || user.Status != common.UserStatusEnabled { | 	if !okay || user.Status != common.UserStatusEnabled { | ||||||
| 		return errors.New("用户名或密码错误,或用户已被封禁") | 		return errors.New("用户名或密码错误,或用户已被封禁") | ||||||
| @@ -235,7 +222,7 @@ func IsAdmin(userId int) bool { | |||||||
| 	var user User | 	var user User | ||||||
| 	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error | 	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("no such user " + err.Error()) | 		common.SysError("no such user " + err.Error()) | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 	return user.Role >= common.RoleAdminUser | 	return user.Role >= common.RoleAdminUser | ||||||
| @@ -294,7 +281,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if config.BatchUpdateEnabled { | 	if common.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -310,7 +297,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if config.BatchUpdateEnabled { | 	if common.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -328,7 +315,7 @@ func GetRootUserEmail() (email string) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||||
| 	if config.BatchUpdateEnabled { | 	if common.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||||
| 		return | 		return | ||||||
| @@ -344,7 +331,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | |||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update user used quota and request count: " + err.Error()) | 		common.SysError("failed to update user used quota and request count: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -355,14 +342,14 @@ func updateUserUsedQuota(id int, quota int) { | |||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update user used quota: " + err.Error()) | 		common.SysError("failed to update user used quota: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateUserRequestCount(id int, count int) { | func updateUserRequestCount(id int, count int) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.SysError("failed to update user request count: " + err.Error()) | 		common.SysError("failed to update user request count: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,8 +1,7 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/songquanpeng/one-api/common/config" | 	"one-api/common" | ||||||
| 	"github.com/songquanpeng/one-api/common/logger" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @@ -29,7 +28,7 @@ func init() { | |||||||
| func InitBatchUpdater() { | func InitBatchUpdater() { | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| 			time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second) | 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) | ||||||
| 			batchUpdate() | 			batchUpdate() | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -46,7 +45,7 @@ func addNewRecord(type_ int, id int, value int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func batchUpdate() { | func batchUpdate() { | ||||||
| 	logger.SysLog("batch update started") | 	common.SysLog("batch update started") | ||||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
| 		batchUpdateLocks[i].Lock() | 		batchUpdateLocks[i].Lock() | ||||||
| 		store := batchUpdateStores[i] | 		store := batchUpdateStores[i] | ||||||
| @@ -58,12 +57,12 @@ func batchUpdate() { | |||||||
| 			case BatchUpdateTypeUserQuota: | 			case BatchUpdateTypeUserQuota: | ||||||
| 				err := increaseUserQuota(key, value) | 				err := increaseUserQuota(key, value) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					logger.SysError("failed to batch update user quota: " + err.Error()) | 					common.SysError("failed to batch update user quota: " + err.Error()) | ||||||
| 				} | 				} | ||||||
| 			case BatchUpdateTypeTokenQuota: | 			case BatchUpdateTypeTokenQuota: | ||||||
| 				err := increaseTokenQuota(key, value) | 				err := increaseTokenQuota(key, value) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					logger.SysError("failed to batch update token quota: " + err.Error()) | 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||||
| 				} | 				} | ||||||
| 			case BatchUpdateTypeUsedQuota: | 			case BatchUpdateTypeUsedQuota: | ||||||
| 				updateUserUsedQuota(key, value) | 				updateUserUsedQuota(key, value) | ||||||
| @@ -74,5 +73,5 @@ func batchUpdate() { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	logger.SysLog("batch update finished") | 	common.SysLog("batch update finished") | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										30
									
								
								providers/aigc2d/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								providers/aigc2d/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | package aigc2d | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) { | ||||||
|  | 	fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") | ||||||
|  | 	headers := p.GetRequestHeaders() | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 发送请求 | ||||||
|  | 	var response base.BalanceResponse | ||||||
|  | 	_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel.UpdateBalance(response.TotalAvailable) | ||||||
|  |  | ||||||
|  | 	return response.TotalAvailable, nil | ||||||
|  | } | ||||||
							
								
								
									
										20
									
								
								providers/aigc2d/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								providers/aigc2d/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | package aigc2d | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | 	"one-api/providers/openai" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Aigc2dProviderFactory struct{} | ||||||
|  |  | ||||||
|  | func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||||
|  | 	return &Aigc2dProvider{ | ||||||
|  | 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Aigc2dProvider struct { | ||||||
|  | 	*openai.OpenAIProvider | ||||||
|  | } | ||||||
							
								
								
									
										35
									
								
								providers/aiproxy/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								providers/aiproxy/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | |||||||
|  | package aiproxy | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) { | ||||||
|  | 	fullRequestURL := "https://aiproxy.io/api/report/getUserOverview" | ||||||
|  | 	headers := make(map[string]string) | ||||||
|  | 	headers["Api-Key"] = channel.Key | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 发送请求 | ||||||
|  | 	var response AIProxyUserOverviewResponse | ||||||
|  | 	_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !response.Success { | ||||||
|  | 		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel.UpdateBalance(response.Data.TotalPoints) | ||||||
|  |  | ||||||
|  | 	return response.Data.TotalPoints, nil | ||||||
|  | } | ||||||
							
								
								
									
										20
									
								
								providers/aiproxy/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								providers/aiproxy/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | package aiproxy | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | 	"one-api/providers/openai" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type AIProxyProviderFactory struct{} | ||||||
|  |  | ||||||
|  | func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||||
|  | 	return &AIProxyProvider{ | ||||||
|  | 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyProvider struct { | ||||||
|  | 	*openai.OpenAIProvider | ||||||
|  | } | ||||||
							
								
								
									
										10
									
								
								providers/aiproxy/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								providers/aiproxy/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | package aiproxy | ||||||
|  |  | ||||||
|  | type AIProxyUserOverviewResponse struct { | ||||||
|  | 	Success   bool   `json:"success"` | ||||||
|  | 	Message   string `json:"message"` | ||||||
|  | 	ErrorCode int    `json:"error_code"` | ||||||
|  | 	Data      struct { | ||||||
|  | 		TotalPoints float64 `json:"totalPoints"` | ||||||
|  | 	} `json:"data"` | ||||||
|  | } | ||||||
							
								
								
									
										52
									
								
								providers/ali/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								providers/ali/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | |||||||
|  | package ali | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"one-api/providers/base" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // 定义供应商工厂 | ||||||
|  | type AliProviderFactory struct{} | ||||||
|  |  | ||||||
|  | // 创建 AliProvider | ||||||
|  | // https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation | ||||||
|  | func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||||
|  | 	return &AliProvider{ | ||||||
|  | 		BaseProvider: base.BaseProvider{ | ||||||
|  | 			BaseURL:         "https://dashscope.aliyuncs.com", | ||||||
|  | 			ChatCompletions: "/api/v1/services/aigc/text-generation/generation", | ||||||
|  | 			Embeddings:      "/api/v1/services/embeddings/text-embedding/text-embedding", | ||||||
|  | 			Context:         c, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliProvider struct { | ||||||
|  | 	base.BaseProvider | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *AliProvider) GetFullRequestURL(requestURL string, modelName string) string { | ||||||
|  | 	baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") | ||||||
|  |  | ||||||
|  | 	if modelName == "qwen-vl-plus" { | ||||||
|  | 		requestURL = "/api/v1/services/aigc/multimodal-generation/generation" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return fmt.Sprintf("%s%s", baseURL, requestURL) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 获取请求头 | ||||||
|  | func (p *AliProvider) GetRequestHeaders() (headers map[string]string) { | ||||||
|  | 	headers = make(map[string]string) | ||||||
|  | 	p.CommonRequestHeaders(headers) | ||||||
|  | 	headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key) | ||||||
|  | 	if p.Channel.Other != "" { | ||||||
|  | 		headers["X-DashScope-Plugin"] = p.Channel.Other | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return headers | ||||||
|  | } | ||||||
							
								
								
									
										258
									
								
								providers/ali/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										258
									
								
								providers/ali/chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,258 @@ | |||||||
|  | package ali | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/types" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // 阿里云响应处理 | ||||||
|  | func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	if aliResponse.Code != "" { | ||||||
|  | 		errWithCode = &types.OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: types.OpenAIError{ | ||||||
|  | 				Message: aliResponse.Message, | ||||||
|  | 				Type:    aliResponse.Code, | ||||||
|  | 				Param:   aliResponse.RequestId, | ||||||
|  | 				Code:    aliResponse.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	OpenAIResponse = types.ChatCompletionResponse{ | ||||||
|  | 		ID:      aliResponse.RequestId, | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   aliResponse.Model, | ||||||
|  | 		Choices: aliResponse.Output.ToChatCompletionChoices(), | ||||||
|  | 		Usage: &types.Usage{ | ||||||
|  | 			PromptTokens:     aliResponse.Usage.InputTokens, | ||||||
|  | 			CompletionTokens: aliResponse.Usage.OutputTokens, | ||||||
|  | 			TotalTokens:      aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const AliEnableSearchModelSuffix = "-internet" | ||||||
|  |  | ||||||
|  | // 获取聊天请求体 | ||||||
|  | func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { | ||||||
|  | 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||||
|  | 	for i := 0; i < len(request.Messages); i++ { | ||||||
|  | 		message := request.Messages[i] | ||||||
|  | 		if request.Model != "qwen-vl-plus" { | ||||||
|  | 			messages = append(messages, AliMessage{ | ||||||
|  | 				Content: message.StringContent(), | ||||||
|  | 				Role:    strings.ToLower(message.Role), | ||||||
|  | 			}) | ||||||
|  | 		} else { | ||||||
|  | 			openaiContent := message.ParseContent() | ||||||
|  | 			var parts []AliMessagePart | ||||||
|  | 			for _, part := range openaiContent { | ||||||
|  | 				if part.Type == types.ContentTypeText { | ||||||
|  | 					parts = append(parts, AliMessagePart{ | ||||||
|  | 						Text: part.Text, | ||||||
|  | 					}) | ||||||
|  | 				} else if part.Type == types.ContentTypeImageURL { | ||||||
|  | 					parts = append(parts, AliMessagePart{ | ||||||
|  | 						Image: part.ImageURL.URL, | ||||||
|  | 					}) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			messages = append(messages, AliMessage{ | ||||||
|  | 				Content: parts, | ||||||
|  | 				Role:    strings.ToLower(message.Role), | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enableSearch := false | ||||||
|  | 	aliModel := request.Model | ||||||
|  | 	if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) { | ||||||
|  | 		enableSearch = true | ||||||
|  | 		aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &AliChatRequest{ | ||||||
|  | 		Model: aliModel, | ||||||
|  | 		Input: AliInput{ | ||||||
|  | 			Messages: messages, | ||||||
|  | 		}, | ||||||
|  | 		Parameters: AliParameters{ | ||||||
|  | 			ResultFormat:      "message", | ||||||
|  | 			EnableSearch:      enableSearch, | ||||||
|  | 			IncrementalOutput: request.Stream, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 聊天 | ||||||
|  | func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||||
|  |  | ||||||
|  | 	requestBody := p.getChatRequestBody(request) | ||||||
|  |  | ||||||
|  | 	fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) | ||||||
|  | 	headers := p.GetRequestHeaders() | ||||||
|  | 	if request.Stream { | ||||||
|  | 		headers["Accept"] = "text/event-stream" | ||||||
|  | 		headers["X-DashScope-SSE"] = "enable" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if request.Stream { | ||||||
|  | 		usage, errWithCode = p.sendStreamRequest(req, request.Model) | ||||||
|  | 		if errWithCode != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if usage == nil { | ||||||
|  | 			usage = &types.Usage{ | ||||||
|  | 				PromptTokens:     0, | ||||||
|  | 				CompletionTokens: 0, | ||||||
|  | 				TotalTokens:      0, | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 	} else { | ||||||
|  | 		aliResponse := &AliChatResponse{ | ||||||
|  | 			Model: request.Model, | ||||||
|  | 		} | ||||||
|  | 		errWithCode = p.SendRequest(req, aliResponse, false) | ||||||
|  | 		if errWithCode != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		usage = &types.Usage{ | ||||||
|  | 			PromptTokens:     aliResponse.Usage.InputTokens, | ||||||
|  | 			CompletionTokens: aliResponse.Usage.OutputTokens, | ||||||
|  | 			TotalTokens:      aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 阿里云响应转OpenAI响应 | ||||||
|  | func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse { | ||||||
|  | 	// chatChoice := aliResponse.Output.ToChatCompletionChoices() | ||||||
|  | 	// jsonBody, _ := json.MarshalIndent(chatChoice, "", "  ") | ||||||
|  | 	// fmt.Println("requestBody:", string(jsonBody)) | ||||||
|  | 	var choice types.ChatCompletionStreamChoice | ||||||
|  | 	choice.Index = aliResponse.Output.Choices[0].Index | ||||||
|  | 	choice.Delta.Content = aliResponse.Output.Choices[0].Message.StringContent() | ||||||
|  | 	// fmt.Println("choice.Delta.Content:", chatChoice[0].Message) | ||||||
|  | 	if aliResponse.Output.Choices[0].FinishReason != "null" { | ||||||
|  | 		finishReason := aliResponse.Output.Choices[0].FinishReason | ||||||
|  | 		choice.FinishReason = &finishReason | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	response := types.ChatCompletionStreamResponse{ | ||||||
|  | 		ID:      aliResponse.RequestId, | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   aliResponse.Model, | ||||||
|  | 		Choices: []types.ChatCompletionStreamChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 发送流请求 | ||||||
|  | func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	defer req.Body.Close() | ||||||
|  |  | ||||||
|  | 	usage = &types.Usage{} | ||||||
|  | 	// 发送请求 | ||||||
|  | 	client := common.GetHttpClient(p.Channel.Proxy) | ||||||
|  | 	resp, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | 	common.PutHttpClient(client) | ||||||
|  |  | ||||||
|  | 	if common.IsFailureStatusCode(resp) { | ||||||
|  | 		return nil, common.HandleErrorResp(resp) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defer resp.Body.Close() | ||||||
|  |  | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 5 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:5] != "data:" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[5:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	common.SetEventStreamHeaders(p.Context) | ||||||
|  | 	lastResponseText := "" | ||||||
|  | 	index := 0 | ||||||
|  | 	p.Context.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var aliResponse AliChatResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if aliResponse.Usage.OutputTokens != 0 { | ||||||
|  | 				usage.PromptTokens = aliResponse.Usage.InputTokens | ||||||
|  | 				usage.CompletionTokens = aliResponse.Usage.OutputTokens | ||||||
|  | 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||||
|  | 			} | ||||||
|  | 			aliResponse.Model = model | ||||||
|  | 			aliResponse.Output.Choices[0].Index = index | ||||||
|  | 			index++ | ||||||
|  | 			response := p.streamResponseAli2OpenAI(&aliResponse) | ||||||
|  | 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||||
|  | 			lastResponseText = aliResponse.Output.Choices[0].Message.StringContent() | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										73
									
								
								providers/ali/embeddings.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								providers/ali/embeddings.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,73 @@ | |||||||
|  | package ali | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/types" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // 嵌入请求处理 | ||||||
|  | func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	if aliResponse.Code != "" { | ||||||
|  | 		return nil, &types.OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: types.OpenAIError{ | ||||||
|  | 				Message: aliResponse.Message, | ||||||
|  | 				Type:    aliResponse.Code, | ||||||
|  | 				Param:   aliResponse.RequestId, | ||||||
|  | 				Code:    aliResponse.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	openAIEmbeddingResponse := &types.EmbeddingResponse{ | ||||||
|  | 		Object: "list", | ||||||
|  | 		Data:   make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)), | ||||||
|  | 		Model:  "text-embedding-v1", | ||||||
|  | 		Usage:  &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, item := range aliResponse.Output.Embeddings { | ||||||
|  | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{ | ||||||
|  | 			Object:    `embedding`, | ||||||
|  | 			Index:     item.TextIndex, | ||||||
|  | 			Embedding: item.Embedding, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return openAIEmbeddingResponse, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // 获取嵌入请求体 | ||||||
|  | func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest { | ||||||
|  | 	return &AliEmbeddingRequest{ | ||||||
|  | 		Model: "text-embedding-v1", | ||||||
|  | 		Input: struct { | ||||||
|  | 			Texts []string `json:"texts"` | ||||||
|  | 		}{ | ||||||
|  | 			Texts: request.ParseInput(), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||||
|  |  | ||||||
|  | 	requestBody := p.getEmbeddingsRequestBody(request) | ||||||
|  | 	fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model) | ||||||
|  | 	headers := p.GetRequestHeaders() | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	aliEmbeddingResponse := &AliEmbeddingResponse{} | ||||||
|  | 	errWithCode = p.SendRequest(req, aliEmbeddingResponse, false) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens} | ||||||
|  |  | ||||||
|  | 	return usage, nil | ||||||
|  | } | ||||||
							
								
								
									
										98
									
								
								providers/ali/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								providers/ali/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,98 @@ | |||||||
|  | package ali | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/types" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type AliError struct { | ||||||
|  | 	Code      string `json:"code"` | ||||||
|  | 	Message   string `json:"message"` | ||||||
|  | 	RequestId string `json:"request_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliUsage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | 	TotalTokens  int `json:"total_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliMessage struct { | ||||||
|  | 	Content any    `json:"content"` | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliMessagePart struct { | ||||||
|  | 	Text  string `json:"text,omitempty"` | ||||||
|  | 	Image string `json:"image,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliInput struct { | ||||||
|  | 	// Prompt  string       `json:"prompt"` | ||||||
|  | 	Messages []AliMessage `json:"messages"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliParameters struct { | ||||||
|  | 	TopP              float64 `json:"top_p,omitempty"` | ||||||
|  | 	TopK              int     `json:"top_k,omitempty"` | ||||||
|  | 	Seed              uint64  `json:"seed,omitempty"` | ||||||
|  | 	EnableSearch      bool    `json:"enable_search,omitempty"` | ||||||
|  | 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | ||||||
|  | 	ResultFormat      string  `json:"result_format,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliChatRequest struct { | ||||||
|  | 	Model      string        `json:"model"` | ||||||
|  | 	Input      AliInput      `json:"input"` | ||||||
|  | 	Parameters AliParameters `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliChoice struct { | ||||||
|  | 	FinishReason string                      `json:"finish_reason"` | ||||||
|  | 	Message      types.ChatCompletionMessage `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliOutput struct { | ||||||
|  | 	Choices []types.ChatCompletionChoice `json:"choices"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice { | ||||||
|  | 	for i := range o.Choices { | ||||||
|  | 		_, ok := o.Choices[i].Message.Content.(string) | ||||||
|  | 		if ok { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		o.Choices[i].Message.Content = o.Choices[i].Message.ParseContent() | ||||||
|  | 	} | ||||||
|  | 	return o.Choices | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliChatResponse struct { | ||||||
|  | 	Output AliOutput `json:"output"` | ||||||
|  | 	Usage  AliUsage  `json:"usage"` | ||||||
|  | 	Model  string    `json:"model,omitempty"` | ||||||
|  | 	AliError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbeddingRequest struct { | ||||||
|  | 	Model string `json:"model"` | ||||||
|  | 	Input struct { | ||||||
|  | 		Texts []string `json:"texts"` | ||||||
|  | 	} `json:"input"` | ||||||
|  | 	Parameters *struct { | ||||||
|  | 		TextType string `json:"text_type,omitempty"` | ||||||
|  | 	} `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbedding struct { | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | 	TextIndex int       `json:"text_index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbeddingResponse struct { | ||||||
|  | 	Output struct { | ||||||
|  | 		Embeddings []AliEmbedding `json:"embeddings"` | ||||||
|  | 	} `json:"output"` | ||||||
|  | 	Usage AliUsage `json:"usage"` | ||||||
|  | 	AliError | ||||||
|  | } | ||||||
							
								
								
									
										30
									
								
								providers/api2d/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								providers/api2d/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | package api2d | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) { | ||||||
|  | 	fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") | ||||||
|  | 	headers := p.GetRequestHeaders() | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 发送请求 | ||||||
|  | 	var response base.BalanceResponse | ||||||
|  | 	_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel.UpdateBalance(response.TotalAvailable) | ||||||
|  |  | ||||||
|  | 	return response.TotalAvailable, nil | ||||||
|  | } | ||||||
							
								
								
									
										21
									
								
								providers/api2d/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								providers/api2d/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | package api2d | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | 	"one-api/providers/openai" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Api2dProviderFactory struct{} | ||||||
|  |  | ||||||
|  | // 创建 Api2dProvider | ||||||
|  | func (f Api2dProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||||
|  | 	return &Api2dProvider{ | ||||||
|  | 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Api2dProvider struct { | ||||||
|  | 	*openai.OpenAIProvider | ||||||
|  | } | ||||||
							
								
								
									
										30
									
								
								providers/api2gpt/balance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								providers/api2gpt/balance.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | package api2gpt | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/model" | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) { | ||||||
|  | 	fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "") | ||||||
|  | 	headers := p.GetRequestHeaders() | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 发送请求 | ||||||
|  | 	var response base.BalanceResponse | ||||||
|  | 	_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return 0, errors.New(errWithCode.OpenAIError.Message) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	channel.UpdateBalance(response.TotalAvailable) | ||||||
|  |  | ||||||
|  | 	return response.TotalRemaining, nil | ||||||
|  | } | ||||||
							
								
								
									
										20
									
								
								providers/api2gpt/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								providers/api2gpt/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | package api2gpt | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | 	"one-api/providers/openai" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Api2gptProviderFactory struct{} | ||||||
|  |  | ||||||
|  | func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||||
|  | 	return &Api2gptProvider{ | ||||||
|  | 		OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Api2gptProvider struct { | ||||||
|  | 	*openai.OpenAIProvider | ||||||
|  | } | ||||||
							
								
								
									
										36
									
								
								providers/azure/base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								providers/azure/base.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | |||||||
|  | package azure | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/providers/base" | ||||||
|  | 	"one-api/providers/openai" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type AzureProviderFactory struct{} | ||||||
|  |  | ||||||
|  | // 创建 AzureProvider | ||||||
|  | func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface { | ||||||
|  | 	return &AzureProvider{ | ||||||
|  | 		OpenAIProvider: openai.OpenAIProvider{ | ||||||
|  | 			BaseProvider: base.BaseProvider{ | ||||||
|  | 				BaseURL:             "", | ||||||
|  | 				Completions:         "/completions", | ||||||
|  | 				ChatCompletions:     "/chat/completions", | ||||||
|  | 				Embeddings:          "/embeddings", | ||||||
|  | 				AudioTranscriptions: "/audio/transcriptions", | ||||||
|  | 				AudioTranslations:   "/audio/translations", | ||||||
|  | 				ImagesGenerations:   "/images/generations", | ||||||
|  | 				// ImagesEdit:          "/images/edit", | ||||||
|  | 				// ImagesVariations:    "/images/variations", | ||||||
|  | 				Context: c, | ||||||
|  | 				// AudioSpeech:         "/audio/speech", | ||||||
|  | 			}, | ||||||
|  | 			IsAzure: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AzureProvider struct { | ||||||
|  | 	openai.OpenAIProvider | ||||||
|  | } | ||||||
							
								
								
									
										103
									
								
								providers/azure/image_generations.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								providers/azure/image_generations.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,103 @@ | |||||||
|  | package azure | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"one-api/providers/openai" | ||||||
|  | 	"one-api/types" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||||
|  | 	if c.Status == "canceled" || c.Status == "failed" { | ||||||
|  | 		errWithCode = &types.OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: types.OpenAIError{ | ||||||
|  | 				Message: c.Error.Message, | ||||||
|  | 				Type:    "one_api_error", | ||||||
|  | 				Code:    c.Error.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	operation_location := resp.Header.Get("operation-location") | ||||||
|  | 	if operation_location == "" { | ||||||
|  | 		return nil, common.ErrorWrapper(errors.New("image url is empty"), "get_images_url_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest("GET", operation_location, common.WithHeader(c.Header)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, common.ErrorWrapper(err, "get_images_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	getImageAzureResponse := ImageAzureResponse{} | ||||||
|  | 	for i := 0; i < 3; i++ { | ||||||
|  | 		// 休眠 2 秒 | ||||||
|  | 		time.Sleep(2 * time.Second) | ||||||
|  | 		_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy) | ||||||
|  | 		fmt.Println("getImageAzureResponse", getImageAzureResponse) | ||||||
|  | 		if errWithCode != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if getImageAzureResponse.Status == "canceled" || getImageAzureResponse.Status == "failed" { | ||||||
|  | 			return nil, &types.OpenAIErrorWithStatusCode{ | ||||||
|  | 				OpenAIError: types.OpenAIError{ | ||||||
|  | 					Message: c.Error.Message, | ||||||
|  | 					Type:    "get_images_request_failed", | ||||||
|  | 					Code:    c.Error.Code, | ||||||
|  | 				}, | ||||||
|  | 				StatusCode: resp.StatusCode, | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if getImageAzureResponse.Status == "succeeded" { | ||||||
|  | 			return getImageAzureResponse.Result, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, common.ErrorWrapper(errors.New("get image Timeout"), "get_images_url_failed", http.StatusInternalServerError) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { | ||||||
|  |  | ||||||
|  | 	requestBody, err := p.GetRequestBody(&request, isModelMapped) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model) | ||||||
|  | 	headers := p.GetRequestHeaders() | ||||||
|  |  | ||||||
|  | 	client := common.NewClient() | ||||||
|  | 	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if request.Model == "dall-e-2" { | ||||||
|  | 		imageAzureResponse := &ImageAzureResponse{ | ||||||
|  | 			Header: headers, | ||||||
|  | 			Proxy:  p.Channel.Proxy, | ||||||
|  | 		} | ||||||
|  | 		errWithCode = p.SendRequest(req, imageAzureResponse, false) | ||||||
|  | 	} else { | ||||||
|  | 		openAIProviderImageResponseResponse := &openai.OpenAIProviderImageResponseResponse{} | ||||||
|  | 		errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	usage = &types.Usage{ | ||||||
|  | 		PromptTokens:     promptTokens, | ||||||
|  | 		CompletionTokens: 0, | ||||||
|  | 		TotalTokens:      promptTokens, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										22
									
								
								providers/azure/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								providers/azure/type.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | package azure | ||||||
|  |  | ||||||
|  | import "one-api/types" | ||||||
|  |  | ||||||
|  | type ImageAzureResponse struct { | ||||||
|  | 	ID      string              `json:"id,omitempty"` | ||||||
|  | 	Created int64               `json:"created,omitempty"` | ||||||
|  | 	Expires int64               `json:"expires,omitempty"` | ||||||
|  | 	Result  types.ImageResponse `json:"result,omitempty"` | ||||||
|  | 	Status  string              `json:"status,omitempty"` | ||||||
|  | 	Error   ImageAzureError     `json:"error,omitempty"` | ||||||
|  | 	Header  map[string]string   `json:"header,omitempty"` | ||||||
|  | 	Proxy   string              `json:"proxy,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ImageAzureError struct { | ||||||
|  | 	Code       string   `json:"code,omitempty"` | ||||||
|  | 	Target     string   `json:"target,omitempty"` | ||||||
|  | 	Message    string   `json:"message,omitempty"` | ||||||
|  | 	Details    []string `json:"details,omitempty"` | ||||||
|  | 	InnerError any      `json:"innererror,omitempty"` | ||||||
|  | } | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user