mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 19:03:43 +08:00 
			
		
		
		
	Compare commits
	
		
			277 Commits
		
	
	
		
			v0.5.11
			...
			v0.6.8-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | ec6ad24810 | ||
|  | c4fe57c165 | ||
|  | 274fcf3d76 | ||
|  | 0fc07ea558 | ||
|  | 1ce1e529ee | ||
|  | d936817de9 | ||
|  | fecaece71b | ||
|  | c135d74f13 | ||
|  | d0369b114f | ||
|  | b21b3b5b46 | ||
|  | ae1cd29f94 | ||
|  | f25aaf7752 | ||
|  | b70a07e814 | ||
|  | 34cb147a74 | ||
|  | 8cc1ee6360 | ||
|  | 5a58426859 | ||
|  | 254b9777c0 | ||
|  | 114c44c6e7 | ||
|  | a3c7e15aed | ||
|  | 3777517f64 | ||
|  | 9fc5f427dc | ||
|  | 864a467886 | ||
|  | ed78b5340b | ||
|  | fee69e7c20 | ||
|  | 9d23a44dbf | ||
|  | 6e4cfb20d5 | ||
|  | ff196b75a7 | ||
|  | 279caf82dc | ||
|  | b1520b308b | ||
|  | ed717211aa | ||
|  | 6ccf3f3cfc | ||
|  | f74577141c | ||
|  | 6aafb7a99e | ||
|  | c1971870fa | ||
|  | f83894c83f | ||
|  | e9981fff36 | ||
|  | 98669d5d48 | ||
|  | 9321427c6e | ||
|  | ceea4c6d4a | ||
|  | b53e00a9b3 | ||
|  | 332c8db0b3 | ||
|  | 3be28da57b | ||
|  | fa74ba0eaa | ||
|  | a9211d66f6 | ||
|  | 07b2fd58d6 | ||
|  | 0acee9a065 | ||
|  | f965469e8a | ||
|  | 03ea60532a | ||
|  | 2457d00afb | ||
|  | 91b80ae879 | ||
|  | 2720e1a358 | ||
|  | 71f4403fd5 | ||
|  | 1f76c80553 | ||
|  | 7e027d2bd0 | ||
|  | 30f373b623 | ||
|  | 1c2654320e | ||
|  | 6cffb116b7 | ||
|  | a84c7b38b7 | ||
|  | 1bd14af47b | ||
|  | 6170b91d1c | ||
|  | 04b49aa0ec | ||
|  | ef88497f25 | ||
|  | 007906216d | ||
|  | e64e7707a0 | ||
|  | ea210b6ed7 | ||
|  | 9026ec7510 | ||
|  | c317872097 | ||
|  | da0842272c | ||
|  | 0a650b85b4 | ||
|  | 24f026d18e | ||
|  | cb33e8aad5 | ||
|  | 779b747e9e | ||
|  | 3d149fedf4 | ||
|  | 83517f687c | ||
|  | e30ebda0fe | ||
|  | d87c55f542 | ||
|  | e5b3e37c46 | ||
|  | 8de489cf06 | ||
|  | d14e4aa01b | ||
|  | 541182102e | ||
|  | b2679cca65 | ||
|  | 8572fac7a2 | ||
|  | a2a00dfbc3 | ||
|  | 129282f4a9 | ||
|  | a873cbd392 | ||
|  | 35ba1da984 | ||
|  | 2369025842 | ||
|  | f452bd481e | ||
|  | ddee58df36 | ||
|  | 520a62e704 | ||
|  | fc9a784950 | ||
|  | 1a0b039bcf | ||
|  | 7bf61f9165 | ||
|  | a10232f43a | ||
|  | af543ab8ec | ||
|  | e086da05b1 | ||
|  | 3af4649b52 | ||
|  | 52c32c0b4a | ||
|  | 3fe2863ff7 | ||
|  | acf8cb6248 | ||
|  | 572fc9ffb8 | ||
|  | 569c04acb0 | ||
|  | 961b4108e6 | ||
|  | 0b8ccb94eb | ||
|  | f586ae0ad8 | ||
|  | 24ed170e7b | ||
|  | f70506eac1 | ||
|  | 8f4d78e24d | ||
|  | cd2707692f | ||
|  | 2ab7d25a80 | ||
|  | f9d914873f | ||
|  | 880e12c855 | ||
|  | 0cb224e62e | ||
|  | a44fb5d482 | ||
|  | eec41849ec | ||
|  | d4347e7a35 | ||
|  | b50b43eb65 | ||
|  | 348adc2b02 | ||
|  | dcf24b98dc | ||
|  | af679e04f4 | ||
|  | 93cbca6a9f | ||
|  | 840ef80d94 | ||
|  | 9a2662af0d | ||
|  | 77f9e75654 | ||
|  | 5b41f57423 | ||
|  | 0bb7db0b44 | ||
|  | 4d61b9937b | ||
|  | 68605800af | ||
|  | c49778c254 | ||
|  | f02c7138ea | ||
|  | ca3228855a | ||
|  | f8cc63f00b | ||
|  | 0a37aa4cbd | ||
|  | 054b00b725 | ||
|  | 76569bb0b6 | ||
|  | 1994256bac | ||
|  | 1f80b0a39f | ||
|  | f73f2e51df | ||
|  | 6f036bd0c9 | ||
|  | fb90747c23 | ||
|  | ed70881a58 | ||
|  | 8b9fa3d6e4 | ||
|  | 8b9813d63b | ||
|  | dc7aaf2de5 | ||
|  | 065da8ef8c | ||
|  | e3cfb1fa52 | ||
|  | f89ae5ad58 | ||
|  | 06a3fc5421 | ||
|  | a9c464ec5a | ||
|  | 3f3c13c98c | ||
|  | 2ba28c72cb | ||
|  | 5e81e19bc8 | ||
|  | 96d7a99312 | ||
|  | 24be9de098 | ||
|  | 5b349efff9 | ||
|  | f76c46d648 | ||
|  | cdfdeea3b4 | ||
|  | 56ddbb842a | ||
|  | 99f81a267c | ||
|  | c243cd5535 | ||
|  | e96b173abe | ||
|  | 4ae311e964 | ||
|  | b14cb748d8 | ||
|  | ade19ba4a2 | ||
|  | 4d86d021c4 | ||
|  | 7a44adb5a7 | ||
|  | 9821bc7281 | ||
|  | 08831881f1 | ||
|  | 0eb2272bb7 | ||
|  | 704ec1a827 | ||
|  | 1d7470d6ad | ||
|  | 1185303346 | ||
|  | c212fcf8d7 | ||
|  | c285e000cc | ||
|  | d25ed4c009 | ||
|  | 7400885fbb | ||
|  | 11af81eb39 | ||
|  | 205aba694f | ||
|  | 8dac3afebc | ||
|  | a07791bf93 | ||
|  | 4bb662c0e4 | ||
|  | 4998d58319 | ||
|  | 190203cf8f | ||
|  | 6325c8e0b4 | ||
|  | b204f6d82b | ||
|  | 752639560f | ||
|  | 996f4d99dd | ||
|  | ebfee3b46c | ||
|  | 3e2e805d61 | ||
|  | 3edf7247c4 | ||
|  | 0926b6206b | ||
|  | 7cd57f3125 | ||
|  | 66efabd5ae | ||
|  | 8ede66a896 | ||
|  | b169173860 | ||
|  | f33555ae78 | ||
|  | c28ec10795 | ||
|  | e3767cbb07 | ||
|  | be9eb59fbb | ||
|  | 89e111ac69 | ||
|  | 2dcef85285 | ||
|  | 79d0cd378a | ||
|  | e99150bdb9 | ||
|  | a72e5fcc9e | ||
|  | 0710f8cd66 | ||
|  | 49cad7d4a5 | ||
|  | a90161cf00 | ||
|  | a45fc7d736 | ||
|  | 45940dcb12 | ||
|  | 969042b001 | ||
|  | 7e7369dbc4 | ||
|  | e54e647170 | ||
|  | 358920c858 | ||
|  | 1ea598c773 | ||
|  | 796be42487 | ||
|  | 5b50eb94e5 | ||
|  | 71c61365eb | ||
|  | b09f979b80 | ||
|  | 12440874b0 | ||
|  | 6ebc99460e | ||
|  | 27ad8bfb98 | ||
|  | 8388aa537f | ||
|  | 2346bf70af | ||
|  | f05b403ca5 | ||
|  | b33616df44 | ||
|  | cf16f44970 | ||
|  | bf2e26a48f | ||
|  | 4fb22ad4ce | ||
|  | 95cfb8e8c9 | ||
|  | c6ace985c2 | ||
|  | 10a926b8f3 | ||
|  | 2df877a352 | ||
|  | 9d8967f7d3 | ||
|  | b35f3523d3 | ||
|  | 82e916b5ff | ||
|  | de18d6fe16 | ||
|  | 1d0b7fb5ae | ||
|  | f9490bb72e | ||
|  | 76467285e8 | ||
|  | df1fd9aa81 | ||
|  | 614c2e0442 | ||
|  | eac6a0b9aa | ||
|  | b747cdbc6f | ||
|  | 6b27d6659a | ||
|  | dc5b781191 | ||
|  | c880b4a9a3 | ||
|  | 565ea58e68 | ||
|  | f141a37a9e | ||
|  | 5b78886ad3 | ||
|  | 87c7c4f0e6 | ||
|  | 4c4a873890 | ||
|  | 0664bdfda1 | ||
|  | 32387d9c20 | ||
|  | bd888f2eb7 | ||
|  | cece77e533 | ||
|  | 2a5468e23c | ||
|  | d0e415893b | ||
|  | 6cf5ce9a7a | ||
|  | f598b9df87 | ||
|  | 532c50d212 | ||
|  | 2acc2f5017 | ||
|  | 604ac56305 | ||
|  | 9383b638a6 | ||
|  | 28d512a675 | ||
|  | de9a58ca0b | ||
|  | 1aa374ccfb | ||
|  | d548a01c59 | ||
|  | 2cd1a78203 | ||
|  | b9d3cb0c45 | ||
|  | ea407f0054 | ||
|  | 26e2e646cb | ||
|  | 4f214c48c6 | ||
|  | 2d760d4a01 | ||
|  | e2ed0399f0 | ||
|  | eed9f5fdf0 | ||
|  | f2c51a494c | ||
|  | 8a4d6f3327 | 
							
								
								
									
										3
									
								
								.env.example
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.env.example
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | PORT=3000 | ||||||
|  | DEBUG=false | ||||||
|  | HTTPS_PROXY=http://localhost:7890 | ||||||
							
								
								
									
										47
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | |||||||
|  | name: CI | ||||||
|  |  | ||||||
|  | # This setup assumes that you run the unit tests with code coverage in the same | ||||||
|  | # workflow that will also print the coverage report as comment to the pull request.  | ||||||
|  | # Therefore, you need to trigger this workflow when a pull request is (re)opened or | ||||||
|  | # when new code is pushed to the branch of the pull request. In addition, you also | ||||||
|  | # need to trigger this workflow when new code is pushed to the main branch because  | ||||||
|  | # we need to upload the code coverage results as artifact for the main branch as | ||||||
|  | # well since it will be the baseline code coverage. | ||||||
|  | #  | ||||||
|  | # We do not want to trigger the workflow for pushes to *any* branch because this | ||||||
|  | # would trigger our jobs twice on pull requests (once from "push" event and once | ||||||
|  | # from "pull_request->synchronize") | ||||||
|  | on: | ||||||
|  |   pull_request: | ||||||
|  |     types: [opened, reopened, synchronize] | ||||||
|  |   push: | ||||||
|  |     branches: | ||||||
|  |       - 'main' | ||||||
|  |  | ||||||
|  | jobs: | ||||||
|  |   unit_tests: | ||||||
|  |     name: "Unit tests" | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - name: Checkout repository | ||||||
|  |         uses: actions/checkout@v4 | ||||||
|  |  | ||||||
|  |       - name: Setup Go | ||||||
|  |         uses: actions/setup-go@v4 | ||||||
|  |         with: | ||||||
|  |           go-version: ^1.22 | ||||||
|  |  | ||||||
|  |       # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a  | ||||||
|  |       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") | ||||||
|  |       # in the next step as well as the next job. | ||||||
|  |       - name: Test | ||||||
|  |         run: go test -cover -coverprofile=coverage.txt ./... | ||||||
|  |       - uses: codecov/codecov-action@v4 | ||||||
|  |         with: | ||||||
|  |           token: ${{ secrets.CODECOV_TOKEN }} | ||||||
|  |  | ||||||
|  |   commit_lint: | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v3 | ||||||
|  |       - uses: wagoid/commitlint-github-action@v6 | ||||||
							
								
								
									
										9
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -3,7 +3,7 @@ name: Publish Docker image (amd64, English) | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - 'v*.*.*' | ||||||
|   workflow_dispatch: |   workflow_dispatch: | ||||||
|     inputs: |     inputs: | ||||||
|       name: |       name: | ||||||
| @@ -20,6 +20,13 @@ jobs: | |||||||
|       - name: Check out the repo |       - name: Check out the repo | ||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|  |  | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi       | ||||||
|  |  | ||||||
|       - name: Save version info |       - name: Save version info | ||||||
|         run: | |         run: | | ||||||
|           git describe --tags > VERSION  |           git describe --tags > VERSION  | ||||||
|   | |||||||
							
								
								
									
										9
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -3,7 +3,7 @@ name: Publish Docker image (amd64) | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - 'v*.*.*' | ||||||
|   workflow_dispatch: |   workflow_dispatch: | ||||||
|     inputs: |     inputs: | ||||||
|       name: |       name: | ||||||
| @@ -20,6 +20,13 @@ jobs: | |||||||
|       - name: Check out the repo |       - name: Check out the repo | ||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|  |  | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi         | ||||||
|  |  | ||||||
|       - name: Save version info |       - name: Save version info | ||||||
|         run: | |         run: | | ||||||
|           git describe --tags > VERSION  |           git describe --tags > VERSION  | ||||||
|   | |||||||
							
								
								
									
										9
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -3,7 +3,7 @@ name: Publish Docker image (arm64) | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - 'v*.*.*' | ||||||
|       - '!*-alpha*' |       - '!*-alpha*' | ||||||
|   workflow_dispatch: |   workflow_dispatch: | ||||||
|     inputs: |     inputs: | ||||||
| @@ -21,6 +21,13 @@ jobs: | |||||||
|       - name: Check out the repo |       - name: Check out the repo | ||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|  |  | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|  |  | ||||||
|       - name: Save version info |       - name: Save version info | ||||||
|         run: | |         run: | | ||||||
|           git describe --tags > VERSION  |           git describe --tags > VERSION  | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,7 +5,7 @@ permissions: | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - 'v*.*.*' | ||||||
|       - '!*-alpha*' |       - '!*-alpha*' | ||||||
|   workflow_dispatch: |   workflow_dispatch: | ||||||
|     inputs: |     inputs: | ||||||
| @@ -20,10 +20,16 @@ jobs: | |||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|         with: |         with: | ||||||
|           fetch-depth: 0 |           fetch-depth: 0 | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
|       - name: Build Frontend (theme default) |       - name: Build Frontend | ||||||
|         env: |         env: | ||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
| @@ -38,7 +44,7 @@ jobs: | |||||||
|       - name: Build Backend (amd64) |       - name: Build Backend (amd64) | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api |           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
|       - name: Build Backend (arm64) |       - name: Build Backend (arm64) | ||||||
|         run: | |         run: | | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,7 +5,7 @@ permissions: | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - 'v*.*.*' | ||||||
|       - '!*-alpha*' |       - '!*-alpha*' | ||||||
|   workflow_dispatch: |   workflow_dispatch: | ||||||
|     inputs: |     inputs: | ||||||
| @@ -20,10 +20,16 @@ jobs: | |||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|         with: |         with: | ||||||
|           fetch-depth: 0 |           fetch-depth: 0 | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
|       - name: Build Frontend (theme default) |       - name: Build Frontend | ||||||
|         env: |         env: | ||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
| @@ -38,7 +44,7 @@ jobs: | |||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos |           go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -5,7 +5,7 @@ permissions: | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     tags: |     tags: | ||||||
|       - '*' |       - 'v*.*.*' | ||||||
|       - '!*-alpha*' |       - '!*-alpha*' | ||||||
|   workflow_dispatch: |   workflow_dispatch: | ||||||
|     inputs: |     inputs: | ||||||
| @@ -23,10 +23,16 @@ jobs: | |||||||
|         uses: actions/checkout@v3 |         uses: actions/checkout@v3 | ||||||
|         with: |         with: | ||||||
|           fetch-depth: 0 |           fetch-depth: 0 | ||||||
|  |       - name: Check repository URL | ||||||
|  |         run: | | ||||||
|  |           REPO_URL=$(git config --get remote.origin.url) | ||||||
|  |           if [[ $REPO_URL == *"pro" ]]; then | ||||||
|  |             exit 1 | ||||||
|  |           fi | ||||||
|       - uses: actions/setup-node@v3 |       - uses: actions/setup-node@v3 | ||||||
|         with: |         with: | ||||||
|           node-version: 16 |           node-version: 16 | ||||||
|       - name: Build Frontend (theme default) |       - name: Build Frontend | ||||||
|         env: |         env: | ||||||
|           CI: "" |           CI: "" | ||||||
|         run: | |         run: | | ||||||
| @@ -41,7 +47,7 @@ jobs: | |||||||
|       - name: Build Backend |       - name: Build Backend | ||||||
|         run: | |         run: | | ||||||
|           go mod download |           go mod download | ||||||
|           go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe |           go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe | ||||||
|       - name: Release |       - name: Release | ||||||
|         uses: softprops/action-gh-release@v1 |         uses: softprops/action-gh-release@v1 | ||||||
|         if: startsWith(github.ref, 'refs/tags/') |         if: startsWith(github.ref, 'refs/tags/') | ||||||
|   | |||||||
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -6,4 +6,7 @@ upload | |||||||
| build | build | ||||||
| *.db-journal | *.db-journal | ||||||
| logs | logs | ||||||
| data | data | ||||||
|  | /web/node_modules | ||||||
|  | cmd.md | ||||||
|  | .env | ||||||
| @@ -12,6 +12,10 @@ WORKDIR /web/berry | |||||||
| RUN npm install | RUN npm install | ||||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||||
|  |  | ||||||
|  | WORKDIR /web/air | ||||||
|  | RUN npm install | ||||||
|  | RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||||
|  |  | ||||||
| FROM golang AS builder2 | FROM golang AS builder2 | ||||||
|  |  | ||||||
| ENV GO111MODULE=on \ | ENV GO111MODULE=on \ | ||||||
| @@ -23,7 +27,7 @@ ADD go.mod go.sum ./ | |||||||
| RUN go mod download | RUN go mod download | ||||||
| COPY . . | COPY . . | ||||||
| COPY --from=builder /web/build ./web/build | COPY --from=builder /web/build ./web/build | ||||||
| RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||||
|  |  | ||||||
| FROM alpine | FROM alpine | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										28
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -101,7 +101,7 @@ Nginx reference configuration: | |||||||
| ``` | ``` | ||||||
| server{ | server{ | ||||||
|    server_name openai.justsong.cn;  # Modify your domain name accordingly |    server_name openai.justsong.cn;  # Modify your domain name accordingly | ||||||
|     |  | ||||||
|    location / { |    location / { | ||||||
|           client_max_body_size  64m; |           client_max_body_size  64m; | ||||||
|           proxy_http_version 1.1; |           proxy_http_version 1.1; | ||||||
| @@ -132,14 +132,14 @@ The initial account username is `root` and password is `123456`. | |||||||
| 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: | 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: | ||||||
|    ```shell |    ```shell | ||||||
|    git clone https://github.com/songquanpeng/one-api.git |    git clone https://github.com/songquanpeng/one-api.git | ||||||
|     |  | ||||||
|    # Build the frontend |    # Build the frontend | ||||||
|    cd one-api/web |    cd one-api/web/default | ||||||
|    npm install |    npm install | ||||||
|    npm run build |    npm run build | ||||||
|     |  | ||||||
|    # Build the backend |    # Build the backend | ||||||
|    cd .. |    cd ../.. | ||||||
|    go mod download |    go mod download | ||||||
|    go build -ldflags "-s -w" -o one-api |    go build -ldflags "-s -w" -o one-api | ||||||
|    ``` |    ``` | ||||||
| @@ -241,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the | |||||||
|     + Example: `SESSION_SECRET=random_string` |     + Example: `SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. | ||||||
|     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |     + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
| 4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | 4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL. | ||||||
|  |     + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` | ||||||
|  | 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` |     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | 6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||||
|     + Example: `SYNC_FREQUENCY=60` |     + Example: `SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | 7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||||
|     + Example: `NODE_TYPE=slave` |     + Example: `NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | 8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` |     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | 9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` |     + Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | 10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||||
|     + Example: `POLLING_INTERVAL=5` |     + Example: `POLLING_INTERVAL=5` | ||||||
|  |  | ||||||
| ### Command Line Parameters | ### Command Line Parameters | ||||||
| @@ -285,7 +287,9 @@ If the channel ID is not provided, load balancing will be used to distribute the | |||||||
|     + Double-check that your interface address and API Key are correct. |     + Double-check that your interface address and API Key are correct. | ||||||
|  |  | ||||||
| ## Related Projects | ## Related Projects | ||||||
| [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | * [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM | ||||||
|  | * [VChart](https://github.com/VisActor/VChart):  More than just a cross-platform charting library, but also an expressive data storyteller. | ||||||
|  | * [VMind](https://github.com/VisActor/VMind):  Not just automatic, but also fantastic. Open-source solution for intelligent visualization. | ||||||
|  |  | ||||||
| ## Note | ## Note | ||||||
| This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								README.ja.md
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								README.ja.md
									
									
									
									
									
								
							| @@ -135,12 +135,12 @@ sudo service nginx restart | |||||||
|    git clone https://github.com/songquanpeng/one-api.git |    git clone https://github.com/songquanpeng/one-api.git | ||||||
|  |  | ||||||
|    # フロントエンドのビルド |    # フロントエンドのビルド | ||||||
|    cd one-api/web |    cd one-api/web/default | ||||||
|    npm install |    npm install | ||||||
|    npm run build |    npm run build | ||||||
|  |  | ||||||
|    # バックエンドのビルド |    # バックエンドのビルド | ||||||
|    cd .. |    cd ../.. | ||||||
|    go mod download |    go mod download | ||||||
|    go build -ldflags "-s -w" -o one-api |    go build -ldflags "-s -w" -o one-api | ||||||
|    ``` |    ``` | ||||||
| @@ -242,17 +242,18 @@ graph LR | |||||||
|     + 例: `SESSION_SECRET=random_string` |     + 例: `SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 | ||||||
|     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` |     + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` | ||||||
| 4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | 4. `LOG_SQL_DSN`: を設定すると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください。 | ||||||
|  | 5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。 | ||||||
|     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` |     + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | 6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 | ||||||
|     + 例: `SYNC_FREQUENCY=60` |     + 例: `SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | 7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。 | ||||||
|     + 例: `NODE_TYPE=slave` |     + 例: `NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | 8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 | ||||||
|     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` |     + 例: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | 9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 | ||||||
|     + 例: `CHANNEL_TEST_FREQUENCY=1440` |     + 例: `CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | 10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 | ||||||
|     + 例: `POLLING_INTERVAL=5` |     + 例: `POLLING_INTERVAL=5` | ||||||
|  |  | ||||||
| ### コマンドラインパラメータ | ### コマンドラインパラメータ | ||||||
|   | |||||||
							
								
								
									
										89
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										89
									
								
								README.md
									
									
									
									
									
								
							| @@ -53,7 +53,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|  |  | ||||||
| > [!NOTE] | > [!NOTE] | ||||||
| > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||||
| >  | > | ||||||
| > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | ||||||
|  |  | ||||||
| > [!WARNING] | > [!WARNING] | ||||||
| @@ -65,21 +65,36 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| ## 功能 | ## 功能 | ||||||
| 1. 支持多种大模型: | 1. 支持多种大模型: | ||||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) |    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) |    + [x] [Anthropic Claude 系列模型](https://anthropic.com) (支持 AWS Claude) | ||||||
|    + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) |    + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) | ||||||
|  |    + [x] [Mistral 系列模型](https://mistral.ai/) | ||||||
|  |    + [x] [字节跳动豆包大模型](https://console.volcengine.com/ark/region:ark+cn-beijing/model) | ||||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) |    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) |    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) |    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) |    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||||
|    + [x] [360 智脑](https://ai.360.cn) |    + [x] [360 智脑](https://ai.360.cn) | ||||||
|    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) |    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||||
|  |    + [x] [Moonshot AI](https://platform.moonshot.cn/) | ||||||
|  |    + [x] [百川大模型](https://platform.baichuan-ai.com) | ||||||
|  |    + [x] [MINIMAX](https://api.minimax.chat/) | ||||||
|  |    + [x] [Groq](https://wow.groq.com/) | ||||||
|  |    + [x] [Ollama](https://github.com/ollama/ollama) | ||||||
|  |    + [x] [零一万物](https://platform.lingyiwanwu.com/) | ||||||
|  |    + [x] [阶跃星辰](https://platform.stepfun.com/) | ||||||
|  |    + [x] [Coze](https://www.coze.com/) | ||||||
|  |    + [x] [Cohere](https://cohere.com/) | ||||||
|  |    + [x] [DeepSeek](https://www.deepseek.com/) | ||||||
|  |    + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) | ||||||
|  |    + [x] [DeepL](https://www.deepl.com/) | ||||||
|  |    + [x] [together.ai](https://www.together.ai/) | ||||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| 5. 支持**多机部署**,[详见此处](#多机部署)。 | 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||||
| 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | 6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 | ||||||
| 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||||
| 8. 支持**通道管理**,批量创建通道。 | 8. 支持**渠道管理**,批量创建渠道。 | ||||||
| 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||||
| 10. 支持渠道**设置模型列表**。 | 10. 支持渠道**设置模型列表**。 | ||||||
| 11. 支持**查看额度明细**。 | 11. 支持**查看额度明细**。 | ||||||
| @@ -93,13 +108,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| 19. 支持丰富的**自定义**设置, | 19. 支持丰富的**自定义**设置, | ||||||
|     1. 支持自定义系统名称,logo 以及页脚。 |     1. 支持自定义系统名称,logo 以及页脚。 | ||||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 |     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||||
| 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | 20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。 | ||||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | 21. 支持 Cloudflare Turnstile 用户校验。 | ||||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 |     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||||
|  |     + 支持使用飞书进行授权登录。 | ||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||||
|  | 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||||
|  |  | ||||||
| ## 部署 | ## 部署 | ||||||
| ### 基于 Docker 进行部署 | ### 基于 Docker 进行部署 | ||||||
| @@ -127,7 +144,7 @@ Nginx 的参考配置: | |||||||
| ``` | ``` | ||||||
| server{ | server{ | ||||||
|    server_name openai.justsong.cn;  # 请根据实际情况修改你的域名 |    server_name openai.justsong.cn;  # 请根据实际情况修改你的域名 | ||||||
|     |  | ||||||
|    location / { |    location / { | ||||||
|           client_max_body_size  64m; |           client_max_body_size  64m; | ||||||
|           proxy_http_version 1.1; |           proxy_http_version 1.1; | ||||||
| @@ -172,14 +189,14 @@ docker-compose ps | |||||||
| 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: | 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: | ||||||
|    ```shell |    ```shell | ||||||
|    git clone https://github.com/songquanpeng/one-api.git |    git clone https://github.com/songquanpeng/one-api.git | ||||||
|     |  | ||||||
|    # 构建前端 |    # 构建前端 | ||||||
|    cd one-api/web |    cd one-api/web/default | ||||||
|    npm install |    npm install | ||||||
|    npm run build |    npm run build | ||||||
|     |  | ||||||
|    # 构建后端 |    # 构建后端 | ||||||
|    cd .. |    cd ../.. | ||||||
|    go mod download |    go mod download | ||||||
|    go build -ldflags "-s -w" -o one-api |    go build -ldflags "-s -w" -o one-api | ||||||
|    ```` |    ```` | ||||||
| @@ -304,7 +321,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo | |||||||
| 例如对于 OpenAI 的官方库: | 例如对于 OpenAI 的官方库: | ||||||
| ```bash | ```bash | ||||||
| OPENAI_API_KEY="sk-xxxxxx" | OPENAI_API_KEY="sk-xxxxxx" | ||||||
| OPENAI_API_BASE="https://<HOST>:<PORT>/v1"  | OPENAI_API_BASE="https://<HOST>:<PORT>/v1" | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ```mermaid | ```mermaid | ||||||
| @@ -323,6 +340,7 @@ graph LR | |||||||
| 不加的话将会使用负载均衡的方式使用多个渠道。 | 不加的话将会使用负载均衡的方式使用多个渠道。 | ||||||
|  |  | ||||||
| ### 环境变量 | ### 环境变量 | ||||||
|  | > One API 支持从 `.env` 文件中读取环境变量,请参照 `.env.example` 文件,使用时请将其重命名为 `.env`。 | ||||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` |    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 |    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||||
| @@ -340,35 +358,44 @@ graph LR | |||||||
|      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 |      + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 | ||||||
|        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 |        + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 | ||||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 |      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | 4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 | ||||||
|  | 5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` |    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | 6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|    + 例子:`MEMORY_CACHE_ENABLED=true` |    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||||
| 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | 7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||||
|    + 例子:`SYNC_FREQUENCY=60` |    + 例子:`SYNC_FREQUENCY=60` | ||||||
| 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | 8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||||
|    + 例子:`NODE_TYPE=slave` |    + 例子:`NODE_TYPE=slave` | ||||||
| 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` |    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | 11. 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | 12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||||
|     + 例子:`POLLING_INTERVAL=5` |     + 例子:`POLLING_INTERVAL=5` | ||||||
| 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | 13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` |     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 |     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||||
| 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | 14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` |     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||||
| 13. 请求频率限制: | 15. 请求频率限制: | ||||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 |     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 |     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||||
| 14. 编码器缓存设置: | 16. 编码器缓存设置: | ||||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 |     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 |     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | 17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | 18. `RELAY_PROXY`:设置后使用该代理来请求 API。 | ||||||
| 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | 19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 | ||||||
| 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | 20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 | ||||||
|  | 21. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||||
|  | 22. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||||
|  | 23. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 | ||||||
|  | 24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||||
|  | 25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||||
|  | 26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||||
|  | 27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||||
|  | 28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
| @@ -407,7 +434,7 @@ https://openai.justsong.cn | |||||||
|    + 检查你的接口地址和 API Key 有没有填对。 |    + 检查你的接口地址和 API Key 有没有填对。 | ||||||
|    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 |    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||||
|    + 上游通道 429 了。 |    + 上游渠道 429 了。 | ||||||
| 7. 升级之后我的数据会丢失吗? | 7. 升级之后我的数据会丢失吗? | ||||||
|    + 如果使用 MySQL,不会。 |    + 如果使用 MySQL,不会。 | ||||||
|    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 |    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||||
| @@ -415,12 +442,14 @@ https://openai.justsong.cn | |||||||
|    + 一般情况下不需要,系统将在初始化的时候自动调整。 |    + 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||||
|    + 如果需要的话,我会在更新日志中说明,并给出脚本。 |    + 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||||
| 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? | 9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? | ||||||
|    + 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。 |    + 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。 | ||||||
|    + 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。 |    + 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。 | ||||||
|  |  | ||||||
| ## 相关项目 | ## 相关项目 | ||||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||||
| * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用 | * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web):  一键拥有你自己的跨平台 ChatGPT 应用 | ||||||
|  | * [VChart](https://github.com/VisActor/VChart):  不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 | ||||||
|  | * [VMind](https://github.com/VisActor/VMind):  不仅自动,还很智能。开源智能可视化解决方案。 | ||||||
|  |  | ||||||
| ## 注意 | ## 注意 | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/blacklist/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package blacklist | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var blackList sync.Map | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	blackList = sync.Map{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func userId2Key(id int) string { | ||||||
|  | 	return fmt.Sprintf("userid_%d", id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BanUser(id int) { | ||||||
|  | 	blackList.Store(userId2Key(id), true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UnbanUser(id int) { | ||||||
|  | 	blackList.Delete(userId2Key(id)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func IsUserBanned(id int) bool { | ||||||
|  | 	_, ok := blackList.Load(userId2Key(id)) | ||||||
|  | 	return ok | ||||||
|  | } | ||||||
							
								
								
									
										60
									
								
								common/client/init.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								common/client/init.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | |||||||
|  | package client | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var HTTPClient *http.Client | ||||||
|  | var ImpatientHTTPClient *http.Client | ||||||
|  | var UserContentRequestHTTPClient *http.Client | ||||||
|  |  | ||||||
|  | func Init() { | ||||||
|  | 	if config.UserContentRequestProxy != "" { | ||||||
|  | 		logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy)) | ||||||
|  | 		proxyURL, err := url.Parse(config.UserContentRequestProxy) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy)) | ||||||
|  | 		} | ||||||
|  | 		transport := &http.Transport{ | ||||||
|  | 			Proxy: http.ProxyURL(proxyURL), | ||||||
|  | 		} | ||||||
|  | 		UserContentRequestHTTPClient = &http.Client{ | ||||||
|  | 			Transport: transport, | ||||||
|  | 			Timeout:   time.Second * time.Duration(config.UserContentRequestTimeout), | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		UserContentRequestHTTPClient = &http.Client{} | ||||||
|  | 	} | ||||||
|  | 	var transport http.RoundTripper | ||||||
|  | 	if config.RelayProxy != "" { | ||||||
|  | 		logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy)) | ||||||
|  | 		proxyURL, err := url.Parse(config.RelayProxy) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy)) | ||||||
|  | 		} | ||||||
|  | 		transport = &http.Transport{ | ||||||
|  | 			Proxy: http.ProxyURL(proxyURL), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if config.RelayTimeout == 0 { | ||||||
|  | 		HTTPClient = &http.Client{ | ||||||
|  | 			Transport: transport, | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		HTTPClient = &http.Client{ | ||||||
|  | 			Timeout:   time.Duration(config.RelayTimeout) * time.Second, | ||||||
|  | 			Transport: transport, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ImpatientHTTPClient = &http.Client{ | ||||||
|  | 		Timeout:   5 * time.Second, | ||||||
|  | 		Transport: transport, | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										153
									
								
								common/config/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								common/config/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,153 @@ | |||||||
|  | package config | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common/env" | ||||||
|  | 	"os" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var SystemName = "One API" | ||||||
|  | var ServerAddress = "http://localhost:3000" | ||||||
|  | var Footer = "" | ||||||
|  | var Logo = "" | ||||||
|  | var TopUpLink = "" | ||||||
|  | var ChatLink = "" | ||||||
|  | var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | ||||||
|  | var DisplayInCurrencyEnabled = true | ||||||
|  | var DisplayTokenStatEnabled = true | ||||||
|  |  | ||||||
|  | // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||||
|  |  | ||||||
|  | var SessionSecret = uuid.New().String() | ||||||
|  |  | ||||||
|  | var OptionMap map[string]string | ||||||
|  | var OptionMapRWMutex sync.RWMutex | ||||||
|  |  | ||||||
|  | var ItemsPerPage = 10 | ||||||
|  | var MaxRecentItems = 100 | ||||||
|  |  | ||||||
|  | var PasswordLoginEnabled = true | ||||||
|  | var PasswordRegisterEnabled = true | ||||||
|  | var EmailVerificationEnabled = false | ||||||
|  | var GitHubOAuthEnabled = false | ||||||
|  | var WeChatAuthEnabled = false | ||||||
|  | var TurnstileCheckEnabled = false | ||||||
|  | var RegisterEnabled = true | ||||||
|  |  | ||||||
|  | var EmailDomainRestrictionEnabled = false | ||||||
|  | var EmailDomainWhitelist = []string{ | ||||||
|  | 	"gmail.com", | ||||||
|  | 	"163.com", | ||||||
|  | 	"126.com", | ||||||
|  | 	"qq.com", | ||||||
|  | 	"outlook.com", | ||||||
|  | 	"hotmail.com", | ||||||
|  | 	"icloud.com", | ||||||
|  | 	"yahoo.com", | ||||||
|  | 	"foxmail.com", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true" | ||||||
|  | var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true" | ||||||
|  | var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true" | ||||||
|  |  | ||||||
|  | var LogConsumeEnabled = true | ||||||
|  |  | ||||||
|  | var SMTPServer = "" | ||||||
|  | var SMTPPort = 587 | ||||||
|  | var SMTPAccount = "" | ||||||
|  | var SMTPFrom = "" | ||||||
|  | var SMTPToken = "" | ||||||
|  |  | ||||||
|  | var GitHubClientId = "" | ||||||
|  | var GitHubClientSecret = "" | ||||||
|  |  | ||||||
|  | var LarkClientId = "" | ||||||
|  | var LarkClientSecret = "" | ||||||
|  |  | ||||||
|  | var WeChatServerAddress = "" | ||||||
|  | var WeChatServerToken = "" | ||||||
|  | var WeChatAccountQRCodeImageURL = "" | ||||||
|  |  | ||||||
|  | var MessagePusherAddress = "" | ||||||
|  | var MessagePusherToken = "" | ||||||
|  |  | ||||||
|  | var TurnstileSiteKey = "" | ||||||
|  | var TurnstileSecretKey = "" | ||||||
|  |  | ||||||
|  | var QuotaForNewUser int64 = 0 | ||||||
|  | var QuotaForInviter int64 = 0 | ||||||
|  | var QuotaForInvitee int64 = 0 | ||||||
|  | var ChannelDisableThreshold = 5.0 | ||||||
|  | var AutomaticDisableChannelEnabled = false | ||||||
|  | var AutomaticEnableChannelEnabled = false | ||||||
|  | var QuotaRemindThreshold int64 = 1000 | ||||||
|  | var PreConsumedQuota int64 = 500 | ||||||
|  | var ApproximateTokenEnabled = false | ||||||
|  | var RetryTimes = 0 | ||||||
|  |  | ||||||
|  | var RootUserEmail = "" | ||||||
|  |  | ||||||
|  | var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||||
|  |  | ||||||
|  | var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||||
|  | var RequestInterval = time.Duration(requestInterval) * time.Second | ||||||
|  |  | ||||||
|  | var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second | ||||||
|  |  | ||||||
|  | var BatchUpdateEnabled = false | ||||||
|  | var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) | ||||||
|  |  | ||||||
|  | var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second | ||||||
|  |  | ||||||
|  | var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | ||||||
|  |  | ||||||
|  | var Theme = env.String("THEME", "default") | ||||||
|  | var ValidThemes = map[string]bool{ | ||||||
|  | 	"default": true, | ||||||
|  | 	"berry":   true, | ||||||
|  | 	"air":     true, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // All duration's unit is seconds | ||||||
|  | // Shouldn't larger then RateLimitKeyExpirationDuration | ||||||
|  | var ( | ||||||
|  | 	GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 240) | ||||||
|  | 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
|  | 	GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 120) | ||||||
|  | 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
|  | 	UploadRateLimitNum            = 10 | ||||||
|  | 	UploadRateLimitDuration int64 = 60 | ||||||
|  |  | ||||||
|  | 	DownloadRateLimitNum            = 10 | ||||||
|  | 	DownloadRateLimitDuration int64 = 60 | ||||||
|  |  | ||||||
|  | 	CriticalRateLimitNum            = 20 | ||||||
|  | 	CriticalRateLimitDuration int64 = 20 * 60 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||||
|  |  | ||||||
|  | var EnableMetric = env.Bool("ENABLE_METRIC", false) | ||||||
|  | var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) | ||||||
|  | var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) | ||||||
|  | var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) | ||||||
|  | var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) | ||||||
|  |  | ||||||
|  | var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") | ||||||
|  |  | ||||||
|  | var GeminiVersion = env.String("GEMINI_VERSION", "v1") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) | ||||||
|  |  | ||||||
|  | var RelayProxy = env.String("RELAY_PROXY", "") | ||||||
|  | var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") | ||||||
|  | var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) | ||||||
| @@ -1,227 +1,6 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import ( | import "time" | ||||||
| 	"os" |  | ||||||
| 	"strconv" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/google/uuid" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var StartTime = time.Now().Unix() // unit: second | var StartTime = time.Now().Unix() // unit: second | ||||||
| var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change | var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change | ||||||
| var SystemName = "One API" |  | ||||||
| var ServerAddress = "http://localhost:3000" |  | ||||||
| var Footer = "" |  | ||||||
| var Logo = "" |  | ||||||
| var TopUpLink = "" |  | ||||||
| var ChatLink = "" |  | ||||||
| var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens |  | ||||||
| var DisplayInCurrencyEnabled = true |  | ||||||
| var DisplayTokenStatEnabled = true |  | ||||||
|  |  | ||||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions |  | ||||||
|  |  | ||||||
| var SessionSecret = uuid.New().String() |  | ||||||
|  |  | ||||||
| var OptionMap map[string]string |  | ||||||
| var OptionMapRWMutex sync.RWMutex |  | ||||||
|  |  | ||||||
| var ItemsPerPage = 10 |  | ||||||
| var MaxRecentItems = 100 |  | ||||||
|  |  | ||||||
| var PasswordLoginEnabled = true |  | ||||||
| var PasswordRegisterEnabled = true |  | ||||||
| var EmailVerificationEnabled = false |  | ||||||
| var GitHubOAuthEnabled = false |  | ||||||
| var WeChatAuthEnabled = false |  | ||||||
| var TurnstileCheckEnabled = false |  | ||||||
| var RegisterEnabled = true |  | ||||||
|  |  | ||||||
| var EmailDomainRestrictionEnabled = false |  | ||||||
| var EmailDomainWhitelist = []string{ |  | ||||||
| 	"gmail.com", |  | ||||||
| 	"163.com", |  | ||||||
| 	"126.com", |  | ||||||
| 	"qq.com", |  | ||||||
| 	"outlook.com", |  | ||||||
| 	"hotmail.com", |  | ||||||
| 	"icloud.com", |  | ||||||
| 	"yahoo.com", |  | ||||||
| 	"foxmail.com", |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" |  | ||||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" |  | ||||||
|  |  | ||||||
| var LogConsumeEnabled = true |  | ||||||
|  |  | ||||||
| var SMTPServer = "" |  | ||||||
| var SMTPPort = 587 |  | ||||||
| var SMTPAccount = "" |  | ||||||
| var SMTPFrom = "" |  | ||||||
| var SMTPToken = "" |  | ||||||
|  |  | ||||||
| var GitHubClientId = "" |  | ||||||
| var GitHubClientSecret = "" |  | ||||||
|  |  | ||||||
| var WeChatServerAddress = "" |  | ||||||
| var WeChatServerToken = "" |  | ||||||
| var WeChatAccountQRCodeImageURL = "" |  | ||||||
|  |  | ||||||
| var TurnstileSiteKey = "" |  | ||||||
| var TurnstileSecretKey = "" |  | ||||||
|  |  | ||||||
| var QuotaForNewUser = 0 |  | ||||||
| var QuotaForInviter = 0 |  | ||||||
| var QuotaForInvitee = 0 |  | ||||||
| var ChannelDisableThreshold = 5.0 |  | ||||||
| var AutomaticDisableChannelEnabled = false |  | ||||||
| var AutomaticEnableChannelEnabled = false |  | ||||||
| var QuotaRemindThreshold = 1000 |  | ||||||
| var PreConsumedQuota = 500 |  | ||||||
| var ApproximateTokenEnabled = false |  | ||||||
| var RetryTimes = 0 |  | ||||||
|  |  | ||||||
| var RootUserEmail = "" |  | ||||||
|  |  | ||||||
| var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" |  | ||||||
|  |  | ||||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) |  | ||||||
| var RequestInterval = time.Duration(requestInterval) * time.Second |  | ||||||
|  |  | ||||||
| var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second |  | ||||||
|  |  | ||||||
| var BatchUpdateEnabled = false |  | ||||||
| var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) |  | ||||||
|  |  | ||||||
| var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second |  | ||||||
|  |  | ||||||
| var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") |  | ||||||
|  |  | ||||||
| var Theme = GetOrDefaultString("THEME", "default") |  | ||||||
| var ValidThemes = map[string]bool{ |  | ||||||
| 	"default": true, |  | ||||||
| 	"berry":   true, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	RequestIdKey = "X-Oneapi-Request-Id" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	RoleGuestUser  = 0 |  | ||||||
| 	RoleCommonUser = 1 |  | ||||||
| 	RoleAdminUser  = 10 |  | ||||||
| 	RoleRootUser   = 100 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var ( |  | ||||||
| 	FileUploadPermission    = RoleGuestUser |  | ||||||
| 	FileDownloadPermission  = RoleGuestUser |  | ||||||
| 	ImageUploadPermission   = RoleGuestUser |  | ||||||
| 	ImageDownloadPermission = RoleGuestUser |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // All duration's unit is seconds |  | ||||||
| // Shouldn't larger then RateLimitKeyExpirationDuration |  | ||||||
| var ( |  | ||||||
| 	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) |  | ||||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 |  | ||||||
|  |  | ||||||
| 	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) |  | ||||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 |  | ||||||
|  |  | ||||||
| 	UploadRateLimitNum            = 10 |  | ||||||
| 	UploadRateLimitDuration int64 = 60 |  | ||||||
|  |  | ||||||
| 	DownloadRateLimitNum            = 10 |  | ||||||
| 	DownloadRateLimitDuration int64 = 60 |  | ||||||
|  |  | ||||||
| 	CriticalRateLimitNum            = 20 |  | ||||||
| 	CriticalRateLimitDuration int64 = 20 * 60 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! |  | ||||||
| 	UserStatusDisabled = 2 // also don't use 0 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value! |  | ||||||
| 	TokenStatusDisabled  = 2 // also don't use 0 |  | ||||||
| 	TokenStatusExpired   = 3 |  | ||||||
| 	TokenStatusExhausted = 4 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value! |  | ||||||
| 	RedemptionCodeStatusDisabled = 2 // also don't use 0 |  | ||||||
| 	RedemptionCodeStatusUsed     = 3 // also don't use 0 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	ChannelStatusUnknown          = 0 |  | ||||||
| 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! |  | ||||||
| 	ChannelStatusManuallyDisabled = 2 // also don't use 0 |  | ||||||
| 	ChannelStatusAutoDisabled     = 3 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	ChannelTypeUnknown        = 0 |  | ||||||
| 	ChannelTypeOpenAI         = 1 |  | ||||||
| 	ChannelTypeAPI2D          = 2 |  | ||||||
| 	ChannelTypeAzure          = 3 |  | ||||||
| 	ChannelTypeCloseAI        = 4 |  | ||||||
| 	ChannelTypeOpenAISB       = 5 |  | ||||||
| 	ChannelTypeOpenAIMax      = 6 |  | ||||||
| 	ChannelTypeOhMyGPT        = 7 |  | ||||||
| 	ChannelTypeCustom         = 8 |  | ||||||
| 	ChannelTypeAILS           = 9 |  | ||||||
| 	ChannelTypeAIProxy        = 10 |  | ||||||
| 	ChannelTypePaLM           = 11 |  | ||||||
| 	ChannelTypeAPI2GPT        = 12 |  | ||||||
| 	ChannelTypeAIGC2D         = 13 |  | ||||||
| 	ChannelTypeAnthropic      = 14 |  | ||||||
| 	ChannelTypeBaidu          = 15 |  | ||||||
| 	ChannelTypeZhipu          = 16 |  | ||||||
| 	ChannelTypeAli            = 17 |  | ||||||
| 	ChannelTypeXunfei         = 18 |  | ||||||
| 	ChannelType360            = 19 |  | ||||||
| 	ChannelTypeOpenRouter     = 20 |  | ||||||
| 	ChannelTypeAIProxyLibrary = 21 |  | ||||||
| 	ChannelTypeFastGPT        = 22 |  | ||||||
| 	ChannelTypeTencent        = 23 |  | ||||||
| 	ChannelTypeGemini         = 24 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ |  | ||||||
| 	"",                                  // 0 |  | ||||||
| 	"https://api.openai.com",            // 1 |  | ||||||
| 	"https://oa.api2d.net",              // 2 |  | ||||||
| 	"",                                  // 3 |  | ||||||
| 	"https://api.closeai-proxy.xyz",     // 4 |  | ||||||
| 	"https://api.openai-sb.com",         // 5 |  | ||||||
| 	"https://api.openaimax.com",         // 6 |  | ||||||
| 	"https://api.ohmygpt.com",           // 7 |  | ||||||
| 	"",                                  // 8 |  | ||||||
| 	"https://api.caipacity.com",         // 9 |  | ||||||
| 	"https://api.aiproxy.io",            // 10 |  | ||||||
| 	"",                                  // 11 |  | ||||||
| 	"https://api.api2gpt.com",           // 12 |  | ||||||
| 	"https://api.aigc2d.com",            // 13 |  | ||||||
| 	"https://api.anthropic.com",         // 14 |  | ||||||
| 	"https://aip.baidubce.com",          // 15 |  | ||||||
| 	"https://open.bigmodel.cn",          // 16 |  | ||||||
| 	"https://dashscope.aliyuncs.com",    // 17 |  | ||||||
| 	"",                                  // 18 |  | ||||||
| 	"https://ai.360.cn",                 // 19 |  | ||||||
| 	"https://openrouter.ai/api",         // 20 |  | ||||||
| 	"https://api.aiproxy.io",            // 21 |  | ||||||
| 	"https://fastgpt.run/api/openapi",   // 22 |  | ||||||
| 	"https://hunyuan.cloud.tencent.com", //23 |  | ||||||
| 	"",                                  //24 |  | ||||||
| } |  | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								common/conv/any.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								common/conv/any.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | package conv | ||||||
|  |  | ||||||
|  | func AsString(v any) string { | ||||||
|  | 	str, _ := v.(string) | ||||||
|  | 	return str | ||||||
|  | } | ||||||
							
								
								
									
										23
									
								
								common/ctxkey/key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								common/ctxkey/key.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | |||||||
|  | package ctxkey | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	Config            = "config" | ||||||
|  | 	Id                = "id" | ||||||
|  | 	Username          = "username" | ||||||
|  | 	Role              = "role" | ||||||
|  | 	Status            = "status" | ||||||
|  | 	Channel           = "channel" | ||||||
|  | 	ChannelId         = "channel_id" | ||||||
|  | 	SpecificChannelId = "specific_channel_id" | ||||||
|  | 	RequestModel      = "request_model" | ||||||
|  | 	ConvertedRequest  = "converted_request" | ||||||
|  | 	OriginalModel     = "original_model" | ||||||
|  | 	Group             = "group" | ||||||
|  | 	ModelMapping      = "model_mapping" | ||||||
|  | 	ChannelName       = "channel_name" | ||||||
|  | 	TokenId           = "token_id" | ||||||
|  | 	TokenName         = "token_name" | ||||||
|  | 	BaseURL           = "base_url" | ||||||
|  | 	AvailableModels   = "available_models" | ||||||
|  | 	KeyRequestBody    = "key_request_body" | ||||||
|  | ) | ||||||
| @@ -1,7 +1,12 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common/env" | ||||||
|  | ) | ||||||
|  |  | ||||||
| var UsingSQLite = false | var UsingSQLite = false | ||||||
| var UsingPostgreSQL = false | var UsingPostgreSQL = false | ||||||
|  | var UsingMySQL = false | ||||||
|  |  | ||||||
| var SQLitePath = "one-api.db" | var SQLitePath = "one-api.db" | ||||||
| var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) | var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) | ||||||
|   | |||||||
| @@ -1,86 +0,0 @@ | |||||||
| package common |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"crypto/rand" |  | ||||||
| 	"crypto/tls" |  | ||||||
| 	"encoding/base64" |  | ||||||
| 	"fmt" |  | ||||||
| 	"net/smtp" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func SendEmail(subject string, receiver string, content string) error { |  | ||||||
| 	if SMTPFrom == "" { // for compatibility |  | ||||||
| 		SMTPFrom = SMTPAccount |  | ||||||
| 	} |  | ||||||
| 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) |  | ||||||
|  |  | ||||||
| 	// Extract domain from SMTPFrom |  | ||||||
| 	parts := strings.Split(SMTPFrom, "@") |  | ||||||
| 	var domain string |  | ||||||
| 	if len(parts) > 1 { |  | ||||||
| 		domain = parts[1] |  | ||||||
| 	} |  | ||||||
| 	// Generate a unique Message-ID |  | ||||||
| 	buf := make([]byte, 16) |  | ||||||
| 	_, err := rand.Read(buf) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	messageId := fmt.Sprintf("<%x@%s>", buf, domain) |  | ||||||
|  |  | ||||||
| 	mail := []byte(fmt.Sprintf("To: %s\r\n"+ |  | ||||||
| 		"From: %s<%s>\r\n"+ |  | ||||||
| 		"Subject: %s\r\n"+ |  | ||||||
| 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 |  | ||||||
| 		"Date: %s\r\n"+ |  | ||||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", |  | ||||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) |  | ||||||
| 	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) |  | ||||||
| 	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) |  | ||||||
| 	to := strings.Split(receiver, ";") |  | ||||||
|  |  | ||||||
| 	if SMTPPort == 465 { |  | ||||||
| 		tlsConfig := &tls.Config{ |  | ||||||
| 			InsecureSkipVerify: true, |  | ||||||
| 			ServerName:         SMTPServer, |  | ||||||
| 		} |  | ||||||
| 		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		client, err := smtp.NewClient(conn, SMTPServer) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		defer client.Close() |  | ||||||
| 		if err = client.Auth(auth); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		if err = client.Mail(SMTPFrom); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		receiverEmails := strings.Split(receiver, ";") |  | ||||||
| 		for _, receiver := range receiverEmails { |  | ||||||
| 			if err = client.Rcpt(receiver); err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		w, err := client.Data() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		_, err = w.Write(mail) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = w.Close() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) |  | ||||||
| 	} |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
| @@ -15,10 +15,7 @@ type embedFileSystem struct { | |||||||
|  |  | ||||||
| func (e embedFileSystem) Exists(prefix string, path string) bool { | func (e embedFileSystem) Exists(prefix string, path string) bool { | ||||||
| 	_, err := e.Open(path) | 	_, err := e.Open(path) | ||||||
| 	if err != nil { | 	return err == nil | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	return true |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { | func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { | ||||||
|   | |||||||
							
								
								
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | |||||||
|  | package env | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"os" | ||||||
|  | 	"strconv" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Bool(env string, defaultValue bool) bool { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return os.Getenv(env) == "true" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Int(env string, defaultValue int) int { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	num, err := strconv.Atoi(os.Getenv(env)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Float64(env string, defaultValue float64) float64 { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	num, err := strconv.ParseFloat(os.Getenv(env), 64) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func String(env string, defaultValue string) string { | ||||||
|  | 	if env == "" || os.Getenv(env) == "" { | ||||||
|  | 		return defaultValue | ||||||
|  | 	} | ||||||
|  | 	return os.Getenv(env) | ||||||
|  | } | ||||||
| @@ -4,16 +4,27 @@ import ( | |||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
| 	"io" | 	"io" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | func GetRequestBody(c *gin.Context) ([]byte, error) { | ||||||
|  | 	requestBody, _ := c.Get(ctxkey.KeyRequestBody) | ||||||
|  | 	if requestBody != nil { | ||||||
|  | 		return requestBody.([]byte), nil | ||||||
|  | 	} | ||||||
| 	requestBody, err := io.ReadAll(c.Request.Body) | 	requestBody, err := io.ReadAll(c.Request.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	err = c.Request.Body.Close() | 	_ = c.Request.Body.Close() | ||||||
|  | 	c.Set(ctxkey.KeyRequestBody, requestBody) | ||||||
|  | 	return requestBody.([]byte), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||||
|  | 	requestBody, err := GetRequestBody(c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -31,3 +42,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | |||||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func SetEventStreamHeaders(c *gin.Context) { | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||||
|  | 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||||
|  | 	c.Writer.Header().Set("Connection", "keep-alive") | ||||||
|  | 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||||
|  | 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										139
									
								
								common/helper/helper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								common/helper/helper.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | |||||||
|  | package helper | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
|  | 	"html/template" | ||||||
|  | 	"log" | ||||||
|  | 	"net" | ||||||
|  | 	"os/exec" | ||||||
|  | 	"runtime" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func OpenBrowser(url string) { | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	switch runtime.GOOS { | ||||||
|  | 	case "linux": | ||||||
|  | 		err = exec.Command("xdg-open", url).Start() | ||||||
|  | 	case "windows": | ||||||
|  | 		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() | ||||||
|  | 	case "darwin": | ||||||
|  | 		err = exec.Command("open", url).Start() | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Println(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetIp() (ip string) { | ||||||
|  | 	ips, err := net.InterfaceAddrs() | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Println(err) | ||||||
|  | 		return ip | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, a := range ips { | ||||||
|  | 		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { | ||||||
|  | 			if ipNet.IP.To4() != nil { | ||||||
|  | 				ip = ipNet.IP.String() | ||||||
|  | 				if strings.HasPrefix(ip, "10") { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				if strings.HasPrefix(ip, "172") { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				if strings.HasPrefix(ip, "192.168") { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				ip = "" | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var sizeKB = 1024 | ||||||
|  | var sizeMB = sizeKB * 1024 | ||||||
|  | var sizeGB = sizeMB * 1024 | ||||||
|  |  | ||||||
|  | func Bytes2Size(num int64) string { | ||||||
|  | 	numStr := "" | ||||||
|  | 	unit := "B" | ||||||
|  | 	if num/int64(sizeGB) > 1 { | ||||||
|  | 		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) | ||||||
|  | 		unit = "GB" | ||||||
|  | 	} else if num/int64(sizeMB) > 1 { | ||||||
|  | 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) | ||||||
|  | 		unit = "MB" | ||||||
|  | 	} else if num/int64(sizeKB) > 1 { | ||||||
|  | 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) | ||||||
|  | 		unit = "KB" | ||||||
|  | 	} else { | ||||||
|  | 		numStr = fmt.Sprintf("%d", num) | ||||||
|  | 	} | ||||||
|  | 	return numStr + " " + unit | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Interface2String(inter interface{}) string { | ||||||
|  | 	switch inter := inter.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		return inter | ||||||
|  | 	case int: | ||||||
|  | 		return fmt.Sprintf("%d", inter) | ||||||
|  | 	case float64: | ||||||
|  | 		return fmt.Sprintf("%f", inter) | ||||||
|  | 	} | ||||||
|  | 	return "Not Implemented" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UnescapeHTML(x string) interface{} { | ||||||
|  | 	return template.HTML(x) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func IntMax(a int, b int) int { | ||||||
|  | 	if a >= b { | ||||||
|  | 		return a | ||||||
|  | 	} else { | ||||||
|  | 		return b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GenRequestID() string { | ||||||
|  | 	return GetTimeString() + random.GetRandomNumberString(8) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetResponseID(c *gin.Context) string { | ||||||
|  | 	logID := c.GetString(RequestIdKey) | ||||||
|  | 	return fmt.Sprintf("chatcmpl-%s", logID) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Max(a int, b int) int { | ||||||
|  | 	if a >= b { | ||||||
|  | 		return a | ||||||
|  | 	} else { | ||||||
|  | 		return b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func AssignOrDefault(value string, defaultValue string) string { | ||||||
|  | 	if len(value) != 0 { | ||||||
|  | 		return value | ||||||
|  | 	} | ||||||
|  | 	return defaultValue | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func MessageWithRequestId(message string, id string) string { | ||||||
|  | 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func String2Int(str string) int { | ||||||
|  | 	num, err := strconv.Atoi(str) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								common/helper/key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								common/helper/key.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | package helper | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RequestIdKey = "X-Oneapi-Request-Id" | ||||||
|  | ) | ||||||
							
								
								
									
										15
									
								
								common/helper/time.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								common/helper/time.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | |||||||
|  | package helper | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func GetTimestamp() int64 { | ||||||
|  | 	return time.Now().Unix() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetTimeString() string { | ||||||
|  | 	now := time.Now() | ||||||
|  | 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||||
|  | } | ||||||
| @@ -3,6 +3,7 @@ package image | |||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/client" | ||||||
| 	"image" | 	"image" | ||||||
| 	_ "image/gif" | 	_ "image/gif" | ||||||
| 	_ "image/jpeg" | 	_ "image/jpeg" | ||||||
| @@ -16,10 +17,10 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| // Regex to match data URL pattern | // Regex to match data URL pattern | ||||||
| var	dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) | var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) | ||||||
|  |  | ||||||
| func IsImageUrl(url string) (bool, error) { | func IsImageUrl(url string) (bool, error) { | ||||||
| 	resp, err := http.Head(url) | 	resp, err := client.UserContentRequestHTTPClient.Head(url) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
| @@ -34,7 +35,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { | |||||||
| 	if !isImage { | 	if !isImage { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	resp, err := http.Get(url) | 	resp, err := client.UserContentRequestHTTPClient.Get(url) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package image_test | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/client" | ||||||
| 	"image" | 	"image" | ||||||
| 	_ "image/gif" | 	_ "image/gif" | ||||||
| 	_ "image/jpeg" | 	_ "image/jpeg" | ||||||
| @@ -12,7 +13,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	img "one-api/common/image" | 	img "github.com/songquanpeng/one-api/common/image" | ||||||
|  |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	_ "golang.org/x/image/webp" | 	_ "golang.org/x/image/webp" | ||||||
| @@ -44,6 +45,11 @@ var ( | |||||||
| 	} | 	} | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func TestMain(m *testing.M) { | ||||||
|  | 	client.Init() | ||||||
|  | 	m.Run() | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestDecode(t *testing.T) { | func TestDecode(t *testing.T) { | ||||||
| 	// Bytes read: varies sometimes | 	// Bytes read: varies sometimes | ||||||
| 	// jpeg: 1063892 | 	// jpeg: 1063892 | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ package common | |||||||
| import ( | import ( | ||||||
| 	"flag" | 	"flag" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| @@ -22,7 +24,7 @@ func printHelp() { | |||||||
| 	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]") | 	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]") | ||||||
| } | } | ||||||
|  |  | ||||||
| func init() { | func Init() { | ||||||
| 	flag.Parse() | 	flag.Parse() | ||||||
|  |  | ||||||
| 	if *PrintVersion { | 	if *PrintVersion { | ||||||
| @@ -37,9 +39,9 @@ func init() { | |||||||
|  |  | ||||||
| 	if os.Getenv("SESSION_SECRET") != "" { | 	if os.Getenv("SESSION_SECRET") != "" { | ||||||
| 		if os.Getenv("SESSION_SECRET") == "random_string" { | 		if os.Getenv("SESSION_SECRET") == "random_string" { | ||||||
| 			SysError("SESSION_SECRET is set to an example value, please change it to a random string.") | 			logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") | ||||||
| 		} else { | 		} else { | ||||||
| 			SessionSecret = os.Getenv("SESSION_SECRET") | 			config.SessionSecret = os.Getenv("SESSION_SECRET") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("SQLITE_PATH") != "" { | 	if os.Getenv("SQLITE_PATH") != "" { | ||||||
| @@ -57,5 +59,6 @@ func init() { | |||||||
| 				log.Fatal(err) | 				log.Fatal(err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | 		logger.LogDir = *LogDir | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										100
									
								
								common/logger.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								common/logger.go
									
									
									
									
									
								
							| @@ -1,100 +0,0 @@ | |||||||
| package common |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"log" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	loggerINFO  = "INFO" |  | ||||||
| 	loggerWarn  = "WARN" |  | ||||||
| 	loggerError = "ERR" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const maxLogCount = 1000000 |  | ||||||
|  |  | ||||||
| var logCount int |  | ||||||
| var setupLogLock sync.Mutex |  | ||||||
| var setupLogWorking bool |  | ||||||
|  |  | ||||||
| func SetupLogger() { |  | ||||||
| 	if *LogDir != "" { |  | ||||||
| 		ok := setupLogLock.TryLock() |  | ||||||
| 		if !ok { |  | ||||||
| 			log.Println("setup log is already working") |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		defer func() { |  | ||||||
| 			setupLogLock.Unlock() |  | ||||||
| 			setupLogWorking = false |  | ||||||
| 		}() |  | ||||||
| 		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) |  | ||||||
| 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Fatal("failed to open log file") |  | ||||||
| 		} |  | ||||||
| 		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) |  | ||||||
| 		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func SysLog(s string) { |  | ||||||
| 	t := time.Now() |  | ||||||
| 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func SysError(s string) { |  | ||||||
| 	t := time.Now() |  | ||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func LogInfo(ctx context.Context, msg string) { |  | ||||||
| 	logHelper(ctx, loggerINFO, msg) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func LogWarn(ctx context.Context, msg string) { |  | ||||||
| 	logHelper(ctx, loggerWarn, msg) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func LogError(ctx context.Context, msg string) { |  | ||||||
| 	logHelper(ctx, loggerError, msg) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func logHelper(ctx context.Context, level string, msg string) { |  | ||||||
| 	writer := gin.DefaultErrorWriter |  | ||||||
| 	if level == loggerINFO { |  | ||||||
| 		writer = gin.DefaultWriter |  | ||||||
| 	} |  | ||||||
| 	id := ctx.Value(RequestIdKey) |  | ||||||
| 	now := time.Now() |  | ||||||
| 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) |  | ||||||
| 	logCount++ // we don't need accurate count, so no lock here |  | ||||||
| 	if logCount > maxLogCount && !setupLogWorking { |  | ||||||
| 		logCount = 0 |  | ||||||
| 		setupLogWorking = true |  | ||||||
| 		go func() { |  | ||||||
| 			SetupLogger() |  | ||||||
| 		}() |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func FatalLog(v ...any) { |  | ||||||
| 	t := time.Now() |  | ||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) |  | ||||||
| 	os.Exit(1) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func LogQuota(quota int) string { |  | ||||||
| 	if DisplayInCurrencyEnabled { |  | ||||||
| 		return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) |  | ||||||
| 	} else { |  | ||||||
| 		return fmt.Sprintf("%d 点额度", quota) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
							
								
								
									
										3
									
								
								common/logger/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								common/logger/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | package logger | ||||||
|  |  | ||||||
|  | var LogDir string | ||||||
							
								
								
									
										116
									
								
								common/logger/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										116
									
								
								common/logger/logger.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,116 @@ | |||||||
|  | package logger | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	loggerDEBUG = "DEBUG" | ||||||
|  | 	loggerINFO  = "INFO" | ||||||
|  | 	loggerWarn  = "WARN" | ||||||
|  | 	loggerError = "ERR" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var setupLogOnce sync.Once | ||||||
|  |  | ||||||
|  | func SetupLogger() { | ||||||
|  | 	setupLogOnce.Do(func() { | ||||||
|  | 		if LogDir != "" { | ||||||
|  | 			var logPath string | ||||||
|  | 			if config.OnlyOneLogFile { | ||||||
|  | 				logPath = filepath.Join(LogDir, "oneapi.log") | ||||||
|  | 			} else { | ||||||
|  | 				logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||||
|  | 			} | ||||||
|  | 			fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||||
|  | 			if err != nil { | ||||||
|  | 				log.Fatal("failed to open log file") | ||||||
|  | 			} | ||||||
|  | 			gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) | ||||||
|  | 			gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SysLog(s string) { | ||||||
|  | 	t := time.Now() | ||||||
|  | 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SysLogf(format string, a ...any) { | ||||||
|  | 	SysLog(fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SysError(s string) { | ||||||
|  | 	t := time.Now() | ||||||
|  | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SysErrorf(format string, a ...any) { | ||||||
|  | 	SysError(fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Debug(ctx context.Context, msg string) { | ||||||
|  | 	if config.DebugEnabled { | ||||||
|  | 		logHelper(ctx, loggerDEBUG, msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Info(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerINFO, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Warn(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerWarn, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Error(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerError, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Debugf(ctx context.Context, format string, a ...any) { | ||||||
|  | 	Debug(ctx, fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Infof(ctx context.Context, format string, a ...any) { | ||||||
|  | 	Info(ctx, fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Warnf(ctx context.Context, format string, a ...any) { | ||||||
|  | 	Warn(ctx, fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Errorf(ctx context.Context, format string, a ...any) { | ||||||
|  | 	Error(ctx, fmt.Sprintf(format, a...)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func logHelper(ctx context.Context, level string, msg string) { | ||||||
|  | 	writer := gin.DefaultErrorWriter | ||||||
|  | 	if level == loggerINFO { | ||||||
|  | 		writer = gin.DefaultWriter | ||||||
|  | 	} | ||||||
|  | 	id := ctx.Value(helper.RequestIdKey) | ||||||
|  | 	if id == nil { | ||||||
|  | 		id = helper.GenRequestID() | ||||||
|  | 	} | ||||||
|  | 	now := time.Now() | ||||||
|  | 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||||
|  | 	SetupLogger() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func FatalLog(v ...any) { | ||||||
|  | 	t := time.Now() | ||||||
|  | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | ||||||
|  | 	os.Exit(1) | ||||||
|  | } | ||||||
							
								
								
									
										105
									
								
								common/message/email.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								common/message/email.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | |||||||
|  | package message | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"net" | ||||||
|  | 	"net/smtp" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func shouldAuth() bool { | ||||||
|  | 	return config.SMTPAccount != "" || config.SMTPToken != "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SendEmail(subject string, receiver string, content string) error { | ||||||
|  | 	if receiver == "" { | ||||||
|  | 		return fmt.Errorf("receiver is empty") | ||||||
|  | 	} | ||||||
|  | 	if config.SMTPFrom == "" { // for compatibility | ||||||
|  | 		config.SMTPFrom = config.SMTPAccount | ||||||
|  | 	} | ||||||
|  | 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) | ||||||
|  |  | ||||||
|  | 	// Extract domain from SMTPFrom | ||||||
|  | 	parts := strings.Split(config.SMTPFrom, "@") | ||||||
|  | 	var domain string | ||||||
|  | 	if len(parts) > 1 { | ||||||
|  | 		domain = parts[1] | ||||||
|  | 	} | ||||||
|  | 	// Generate a unique Message-ID | ||||||
|  | 	buf := make([]byte, 16) | ||||||
|  | 	_, err := rand.Read(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	messageId := fmt.Sprintf("<%x@%s>", buf, domain) | ||||||
|  |  | ||||||
|  | 	mail := []byte(fmt.Sprintf("To: %s\r\n"+ | ||||||
|  | 		"From: %s<%s>\r\n"+ | ||||||
|  | 		"Subject: %s\r\n"+ | ||||||
|  | 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 | ||||||
|  | 		"Date: %s\r\n"+ | ||||||
|  | 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||||
|  | 		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||||
|  |  | ||||||
|  | 	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) | ||||||
|  | 	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) | ||||||
|  | 	to := strings.Split(receiver, ";") | ||||||
|  |  | ||||||
|  | 	if config.SMTPPort == 465 || !shouldAuth() { | ||||||
|  | 		// need advanced client | ||||||
|  | 		var conn net.Conn | ||||||
|  | 		var err error | ||||||
|  | 		if config.SMTPPort == 465 { | ||||||
|  | 			tlsConfig := &tls.Config{ | ||||||
|  | 				InsecureSkipVerify: true, | ||||||
|  | 				ServerName:         config.SMTPServer, | ||||||
|  | 			} | ||||||
|  | 			conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) | ||||||
|  | 		} else { | ||||||
|  | 			conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		client, err := smtp.NewClient(conn, config.SMTPServer) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		defer client.Close() | ||||||
|  | 		if shouldAuth() { | ||||||
|  | 			if err = client.Auth(auth); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if err = client.Mail(config.SMTPFrom); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		receiverEmails := strings.Split(receiver, ";") | ||||||
|  | 		for _, receiver := range receiverEmails { | ||||||
|  | 			if err = client.Rcpt(receiver); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		w, err := client.Data() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		_, err = w.Write(mail) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		err = w.Close() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
							
								
								
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | package message | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	ByAll           = "all" | ||||||
|  | 	ByEmail         = "email" | ||||||
|  | 	ByMessagePusher = "message_pusher" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Notify(by string, title string, description string, content string) error { | ||||||
|  | 	if by == ByEmail { | ||||||
|  | 		return SendEmail(title, config.RootUserEmail, content) | ||||||
|  | 	} | ||||||
|  | 	if by == ByMessagePusher { | ||||||
|  | 		return SendMessage(title, description, content) | ||||||
|  | 	} | ||||||
|  | 	return fmt.Errorf("unknown notify method: %s", by) | ||||||
|  | } | ||||||
							
								
								
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | package message | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type request struct { | ||||||
|  | 	Title       string `json:"title"` | ||||||
|  | 	Description string `json:"description"` | ||||||
|  | 	Content     string `json:"content"` | ||||||
|  | 	URL         string `json:"url"` | ||||||
|  | 	Channel     string `json:"channel"` | ||||||
|  | 	Token       string `json:"token"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type response struct { | ||||||
|  | 	Success bool   `json:"success"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SendMessage(title string, description string, content string) error { | ||||||
|  | 	if config.MessagePusherAddress == "" { | ||||||
|  | 		return errors.New("message pusher address is not set") | ||||||
|  | 	} | ||||||
|  | 	req := request{ | ||||||
|  | 		Title:       title, | ||||||
|  | 		Description: description, | ||||||
|  | 		Content:     content, | ||||||
|  | 		Token:       config.MessagePusherToken, | ||||||
|  | 	} | ||||||
|  | 	data, err := json.Marshal(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	resp, err := http.Post(config.MessagePusherAddress, | ||||||
|  | 		"application/json", bytes.NewBuffer(data)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	var res response | ||||||
|  | 	err = json.NewDecoder(resp.Body).Decode(&res) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if !res.Success { | ||||||
|  | 		return errors.New(res.Message) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
| @@ -1,161 +0,0 @@ | |||||||
| package common |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var DalleSizeRatios = map[string]map[string]float64{ |  | ||||||
| 	"dall-e-2": { |  | ||||||
| 		"256x256":   1, |  | ||||||
| 		"512x512":   1.125, |  | ||||||
| 		"1024x1024": 1.25, |  | ||||||
| 	}, |  | ||||||
| 	"dall-e-3": { |  | ||||||
| 		"1024x1024": 1, |  | ||||||
| 		"1024x1792": 2, |  | ||||||
| 		"1792x1024": 2, |  | ||||||
| 	}, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DalleGenerationImageAmounts = map[string][2]int{ |  | ||||||
| 	"dall-e-2": {1, 10}, |  | ||||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DalleImagePromptLengthLimitations = map[string]int{ |  | ||||||
| 	"dall-e-2": 1000, |  | ||||||
| 	"dall-e-3": 4000, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ModelRatio |  | ||||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility |  | ||||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf |  | ||||||
| // https://openai.com/pricing |  | ||||||
| // TODO: when a new api is enabled, check the pricing here |  | ||||||
| // 1 === $0.002 / 1K tokens |  | ||||||
| // 1 === ¥0.014 / 1k tokens |  | ||||||
| var ModelRatio = map[string]float64{ |  | ||||||
| 	"gpt-4":                     15, |  | ||||||
| 	"gpt-4-0314":                15, |  | ||||||
| 	"gpt-4-0613":                15, |  | ||||||
| 	"gpt-4-32k":                 30, |  | ||||||
| 	"gpt-4-32k-0314":            30, |  | ||||||
| 	"gpt-4-32k-0613":            30, |  | ||||||
| 	"gpt-4-1106-preview":        5,    // $0.01 / 1K tokens |  | ||||||
| 	"gpt-4-vision-preview":      5,    // $0.01 / 1K tokens |  | ||||||
| 	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens |  | ||||||
| 	"gpt-3.5-turbo-0301":        0.75, |  | ||||||
| 	"gpt-3.5-turbo-0613":        0.75, |  | ||||||
| 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens |  | ||||||
| 	"gpt-3.5-turbo-16k-0613":    1.5, |  | ||||||
| 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens |  | ||||||
| 	"gpt-3.5-turbo-1106":        0.5,  // $0.001 / 1K tokens |  | ||||||
| 	"davinci-002":               1,    // $0.002 / 1K tokens |  | ||||||
| 	"babbage-002":               0.2,  // $0.0004 / 1K tokens |  | ||||||
| 	"text-ada-001":              0.2, |  | ||||||
| 	"text-babbage-001":          0.25, |  | ||||||
| 	"text-curie-001":            1, |  | ||||||
| 	"text-davinci-002":          10, |  | ||||||
| 	"text-davinci-003":          10, |  | ||||||
| 	"text-davinci-edit-001":     10, |  | ||||||
| 	"code-davinci-edit-001":     10, |  | ||||||
| 	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens |  | ||||||
| 	"tts-1":                     7.5, // $0.015 / 1K characters |  | ||||||
| 	"tts-1-1106":                7.5, |  | ||||||
| 	"tts-1-hd":                  15, // $0.030 / 1K characters |  | ||||||
| 	"tts-1-hd-1106":             15, |  | ||||||
| 	"davinci":                   10, |  | ||||||
| 	"curie":                     10, |  | ||||||
| 	"babbage":                   10, |  | ||||||
| 	"ada":                       10, |  | ||||||
| 	"text-embedding-ada-002":    0.05, |  | ||||||
| 	"text-search-ada-doc-001":   10, |  | ||||||
| 	"text-moderation-stable":    0.1, |  | ||||||
| 	"text-moderation-latest":    0.1, |  | ||||||
| 	"dall-e-2":                  8,      // $0.016 - $0.020 / image |  | ||||||
| 	"dall-e-3":                  20,     // $0.040 - $0.120 / image |  | ||||||
| 	"claude-instant-1":          0.815,  // $1.63 / 1M tokens |  | ||||||
| 	"claude-2":                  5.51,   // $11.02 / 1M tokens |  | ||||||
| 	"claude-2.0":                5.51,   // $11.02 / 1M tokens |  | ||||||
| 	"claude-2.1":                5.51,   // $11.02 / 1M tokens |  | ||||||
| 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens |  | ||||||
| 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens |  | ||||||
| 	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens |  | ||||||
| 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens |  | ||||||
| 	"PaLM-2":                    1, |  | ||||||
| 	"gemini-pro":                1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens |  | ||||||
| 	"gemini-pro-vision":         1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens |  | ||||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens |  | ||||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens |  | ||||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens |  | ||||||
| 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens |  | ||||||
| 	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens  // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing |  | ||||||
| 	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens |  | ||||||
| 	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens |  | ||||||
| 	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens |  | ||||||
| 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens |  | ||||||
| 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens |  | ||||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens |  | ||||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens |  | ||||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens |  | ||||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens |  | ||||||
| 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { |  | ||||||
| 	jsonBytes, err := json.Marshal(ModelRatio) |  | ||||||
| 	if err != nil { |  | ||||||
| 		SysError("error marshalling model ratio: " + err.Error()) |  | ||||||
| 	} |  | ||||||
| 	return string(jsonBytes) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func UpdateModelRatioByJSONString(jsonStr string) error { |  | ||||||
| 	ModelRatio = make(map[string]float64) |  | ||||||
| 	return json.Unmarshal([]byte(jsonStr), &ModelRatio) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetModelRatio(name string) float64 { |  | ||||||
| 	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { |  | ||||||
| 		name = strings.TrimSuffix(name, "-internet") |  | ||||||
| 	} |  | ||||||
| 	ratio, ok := ModelRatio[name] |  | ||||||
| 	if !ok { |  | ||||||
| 		SysError("model ratio not found: " + name) |  | ||||||
| 		return 30 |  | ||||||
| 	} |  | ||||||
| 	return ratio |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetCompletionRatio(name string) float64 { |  | ||||||
| 	if strings.HasPrefix(name, "gpt-3.5") { |  | ||||||
| 		if strings.HasSuffix(name, "1106") { |  | ||||||
| 			return 2 |  | ||||||
| 		} |  | ||||||
| 		if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { |  | ||||||
| 			// TODO: clear this after 2023-12-11 |  | ||||||
| 			now := time.Now() |  | ||||||
| 			// https://platform.openai.com/docs/models/continuous-model-upgrades |  | ||||||
| 			// if after 2023-12-11, use 2 |  | ||||||
| 			if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { |  | ||||||
| 				return 2 |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return 1.333333 |  | ||||||
| 	} |  | ||||||
| 	if strings.HasPrefix(name, "gpt-4") { |  | ||||||
| 		if strings.HasSuffix(name, "preview") { |  | ||||||
| 			return 3 |  | ||||||
| 		} |  | ||||||
| 		return 2 |  | ||||||
| 	} |  | ||||||
| 	if strings.HasPrefix(name, "claude-instant-1") { |  | ||||||
| 		return 3.38 |  | ||||||
| 	} |  | ||||||
| 	if strings.HasPrefix(name, "claude-2") { |  | ||||||
| 		return 2.965517 |  | ||||||
| 	} |  | ||||||
| 	return 1 |  | ||||||
| } |  | ||||||
							
								
								
									
										52
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | |||||||
|  | package network | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"net" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func splitSubnets(subnets string) []string { | ||||||
|  | 	res := strings.Split(subnets, ",") | ||||||
|  | 	for i := 0; i < len(res); i++ { | ||||||
|  | 		res[i] = strings.TrimSpace(res[i]) | ||||||
|  | 	} | ||||||
|  | 	return res | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func isValidSubnet(subnet string) error { | ||||||
|  | 	_, _, err := net.ParseCIDR(subnet) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("failed to parse subnet: %w", err) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { | ||||||
|  | 	_, ipNet, err := net.ParseCIDR(subnet) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return ipNet.Contains(net.ParseIP(ip)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func IsValidSubnets(subnets string) error { | ||||||
|  | 	for _, subnet := range splitSubnets(subnets) { | ||||||
|  | 		if err := isValidSubnet(subnet); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { | ||||||
|  | 	for _, subnet := range splitSubnets(subnets) { | ||||||
|  | 		if isIpInSubnet(ctx, ip, subnet) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
							
								
								
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | package network | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	. "github.com/smartystreets/goconvey/convey" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestIsIpInSubnet(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	ip1 := "192.168.0.5" | ||||||
|  | 	ip2 := "125.216.250.89" | ||||||
|  | 	subnet := "192.168.0.0/24" | ||||||
|  | 	Convey("TestIsIpInSubnet", t, func() { | ||||||
|  | 		So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) | ||||||
|  | 		So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										61
									
								
								common/random/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								common/random/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | |||||||
|  | package random | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func GetUUID() string { | ||||||
|  | 	code := uuid.New().String() | ||||||
|  | 	code = strings.Replace(code, "-", "", -1) | ||||||
|  | 	return code | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||||
|  | const keyNumbers = "0123456789" | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GenerateKey() string { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | 	key := make([]byte, 48) | ||||||
|  | 	for i := 0; i < 16; i++ { | ||||||
|  | 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||||
|  | 	} | ||||||
|  | 	uuid_ := GetUUID() | ||||||
|  | 	for i := 0; i < 32; i++ { | ||||||
|  | 		c := uuid_[i] | ||||||
|  | 		if i%2 == 0 && c >= 'a' && c <= 'z' { | ||||||
|  | 			c = c - 'a' + 'A' | ||||||
|  | 		} | ||||||
|  | 		key[i+16] = c | ||||||
|  | 	} | ||||||
|  | 	return string(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetRandomString(length int) string { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | 	key := make([]byte, length) | ||||||
|  | 	for i := 0; i < length; i++ { | ||||||
|  | 		key[i] = keyChars[rand.Intn(len(keyChars))] | ||||||
|  | 	} | ||||||
|  | 	return string(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetRandomNumberString(length int) string { | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | 	key := make([]byte, length) | ||||||
|  | 	for i := 0; i < length; i++ { | ||||||
|  | 		key[i] = keyNumbers[rand.Intn(len(keyNumbers))] | ||||||
|  | 	} | ||||||
|  | 	return string(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // RandRange returns a random number between min and max (max is not included) | ||||||
|  | func RandRange(min, max int) int { | ||||||
|  | 	return min + rand.Intn(max-min) | ||||||
|  | } | ||||||
| @@ -3,6 +3,7 @@ package common | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"github.com/go-redis/redis/v8" | 	"github.com/go-redis/redis/v8" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"os" | 	"os" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @@ -14,18 +15,18 @@ var RedisEnabled = true | |||||||
| func InitRedisClient() (err error) { | func InitRedisClient() (err error) { | ||||||
| 	if os.Getenv("REDIS_CONN_STRING") == "" { | 	if os.Getenv("REDIS_CONN_STRING") == "" { | ||||||
| 		RedisEnabled = false | 		RedisEnabled = false | ||||||
| 		SysLog("REDIS_CONN_STRING not set, Redis is not enabled") | 		logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("SYNC_FREQUENCY") == "" { | 	if os.Getenv("SYNC_FREQUENCY") == "" { | ||||||
| 		RedisEnabled = false | 		RedisEnabled = false | ||||||
| 		SysLog("SYNC_FREQUENCY not set, Redis is disabled") | 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	SysLog("Redis is enabled") | 	logger.SysLog("Redis is enabled") | ||||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		FatalLog("failed to parse Redis connection string: " + err.Error()) | 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	RDB = redis.NewClient(opt) | 	RDB = redis.NewClient(opt) | ||||||
|  |  | ||||||
| @@ -34,7 +35,7 @@ func InitRedisClient() (err error) { | |||||||
|  |  | ||||||
| 	_, err = RDB.Ping(ctx).Result() | 	_, err = RDB.Ping(ctx).Result() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		FatalLog("Redis ping test failed: " + err.Error()) | 		logger.FatalLog("Redis ping test failed: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -42,7 +43,7 @@ func InitRedisClient() (err error) { | |||||||
| func ParseRedisOption() *redis.Options { | func ParseRedisOption() *redis.Options { | ||||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		FatalLog("failed to parse Redis connection string: " + err.Error()) | 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return opt | 	return opt | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package render | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func StringData(c *gin.Context, str string) { | ||||||
|  | 	str = strings.TrimPrefix(str, "data: ") | ||||||
|  | 	str = strings.TrimSuffix(str, "\r") | ||||||
|  | 	c.Render(-1, common.CustomEvent{Data: "data: " + str}) | ||||||
|  | 	c.Writer.Flush() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ObjectData(c *gin.Context, object interface{}) error { | ||||||
|  | 	jsonData, err := json.Marshal(object) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("error marshalling object: %w", err) | ||||||
|  | 	} | ||||||
|  | 	StringData(c, string(jsonData)) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Done(c *gin.Context) { | ||||||
|  | 	StringData(c, "[DONE]") | ||||||
|  | } | ||||||
							
								
								
									
										212
									
								
								common/utils.go
									
									
									
									
									
								
							
							
						
						
									
										212
									
								
								common/utils.go
									
									
									
									
									
								
							| @@ -2,215 +2,13 @@ package common | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/google/uuid" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"html/template" |  | ||||||
| 	"log" |  | ||||||
| 	"math/rand" |  | ||||||
| 	"net" |  | ||||||
| 	"os" |  | ||||||
| 	"os/exec" |  | ||||||
| 	"runtime" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func OpenBrowser(url string) { | func LogQuota(quota int64) string { | ||||||
| 	var err error | 	if config.DisplayInCurrencyEnabled { | ||||||
|  | 		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) | ||||||
| 	switch runtime.GOOS { |  | ||||||
| 	case "linux": |  | ||||||
| 		err = exec.Command("xdg-open", url).Start() |  | ||||||
| 	case "windows": |  | ||||||
| 		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() |  | ||||||
| 	case "darwin": |  | ||||||
| 		err = exec.Command("open", url).Start() |  | ||||||
| 	} |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Println(err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetIp() (ip string) { |  | ||||||
| 	ips, err := net.InterfaceAddrs() |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Println(err) |  | ||||||
| 		return ip |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, a := range ips { |  | ||||||
| 		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { |  | ||||||
| 			if ipNet.IP.To4() != nil { |  | ||||||
| 				ip = ipNet.IP.String() |  | ||||||
| 				if strings.HasPrefix(ip, "10") { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				if strings.HasPrefix(ip, "172") { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				if strings.HasPrefix(ip, "192.168") { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				ip = "" |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var sizeKB = 1024 |  | ||||||
| var sizeMB = sizeKB * 1024 |  | ||||||
| var sizeGB = sizeMB * 1024 |  | ||||||
|  |  | ||||||
| func Bytes2Size(num int64) string { |  | ||||||
| 	numStr := "" |  | ||||||
| 	unit := "B" |  | ||||||
| 	if num/int64(sizeGB) > 1 { |  | ||||||
| 		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) |  | ||||||
| 		unit = "GB" |  | ||||||
| 	} else if num/int64(sizeMB) > 1 { |  | ||||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) |  | ||||||
| 		unit = "MB" |  | ||||||
| 	} else if num/int64(sizeKB) > 1 { |  | ||||||
| 		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) |  | ||||||
| 		unit = "KB" |  | ||||||
| 	} else { | 	} else { | ||||||
| 		numStr = fmt.Sprintf("%d", num) | 		return fmt.Sprintf("%d 点额度", quota) | ||||||
| 	} |  | ||||||
| 	return numStr + " " + unit |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Seconds2Time(num int) (time string) { |  | ||||||
| 	if num/31104000 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/31104000) + " 年 " |  | ||||||
| 		num %= 31104000 |  | ||||||
| 	} |  | ||||||
| 	if num/2592000 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/2592000) + " 个月 " |  | ||||||
| 		num %= 2592000 |  | ||||||
| 	} |  | ||||||
| 	if num/86400 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/86400) + " 天 " |  | ||||||
| 		num %= 86400 |  | ||||||
| 	} |  | ||||||
| 	if num/3600 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/3600) + " 小时 " |  | ||||||
| 		num %= 3600 |  | ||||||
| 	} |  | ||||||
| 	if num/60 > 0 { |  | ||||||
| 		time += strconv.Itoa(num/60) + " 分钟 " |  | ||||||
| 		num %= 60 |  | ||||||
| 	} |  | ||||||
| 	time += strconv.Itoa(num) + " 秒" |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Interface2String(inter interface{}) string { |  | ||||||
| 	switch inter.(type) { |  | ||||||
| 	case string: |  | ||||||
| 		return inter.(string) |  | ||||||
| 	case int: |  | ||||||
| 		return fmt.Sprintf("%d", inter.(int)) |  | ||||||
| 	case float64: |  | ||||||
| 		return fmt.Sprintf("%f", inter.(float64)) |  | ||||||
| 	} |  | ||||||
| 	return "Not Implemented" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func UnescapeHTML(x string) interface{} { |  | ||||||
| 	return template.HTML(x) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func IntMax(a int, b int) int { |  | ||||||
| 	if a >= b { |  | ||||||
| 		return a |  | ||||||
| 	} else { |  | ||||||
| 		return b |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUUID() string { |  | ||||||
| 	code := uuid.New().String() |  | ||||||
| 	code = strings.Replace(code, "-", "", -1) |  | ||||||
| 	return code |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" |  | ||||||
|  |  | ||||||
| func init() { |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GenerateKey() string { |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) |  | ||||||
| 	key := make([]byte, 48) |  | ||||||
| 	for i := 0; i < 16; i++ { |  | ||||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] |  | ||||||
| 	} |  | ||||||
| 	uuid_ := GetUUID() |  | ||||||
| 	for i := 0; i < 32; i++ { |  | ||||||
| 		c := uuid_[i] |  | ||||||
| 		if i%2 == 0 && c >= 'a' && c <= 'z' { |  | ||||||
| 			c = c - 'a' + 'A' |  | ||||||
| 		} |  | ||||||
| 		key[i+16] = c |  | ||||||
| 	} |  | ||||||
| 	return string(key) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetRandomString(length int) string { |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) |  | ||||||
| 	key := make([]byte, length) |  | ||||||
| 	for i := 0; i < length; i++ { |  | ||||||
| 		key[i] = keyChars[rand.Intn(len(keyChars))] |  | ||||||
| 	} |  | ||||||
| 	return string(key) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetTimestamp() int64 { |  | ||||||
| 	return time.Now().Unix() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetTimeString() string { |  | ||||||
| 	now := time.Now() |  | ||||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Max(a int, b int) int { |  | ||||||
| 	if a >= b { |  | ||||||
| 		return a |  | ||||||
| 	} else { |  | ||||||
| 		return b |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetOrDefault(env string, defaultValue int) int { |  | ||||||
| 	if env == "" || os.Getenv(env) == "" { |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	num, err := strconv.Atoi(os.Getenv(env)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	return num |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetOrDefaultString(env string, defaultValue string) string { |  | ||||||
| 	if env == "" || os.Getenv(env) == "" { |  | ||||||
| 		return defaultValue |  | ||||||
| 	} |  | ||||||
| 	return os.Getenv(env) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func MessageWithRequestId(message string, id string) string { |  | ||||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func String2Int(str string) int { |  | ||||||
| 	num, err := strconv.Atoi(str) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0 |  | ||||||
| 	} |  | ||||||
| 	return num |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| package controller | package auth | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| @@ -7,9 +7,12 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
|  | 	"github.com/songquanpeng/one-api/controller" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, errors.New("无效的参数") | 		return nil, errors.New("无效的参数") | ||||||
| 	} | 	} | ||||||
| 	values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} | 	values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} | ||||||
| 	jsonData, err := json.Marshal(values) | 	jsonData, err := json.Marshal(values) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
| 	} | 	} | ||||||
| 	res, err := client.Do(req) | 	res, err := client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysLog(err.Error()) | 		logger.SysLog(err.Error()) | ||||||
| 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | ||||||
| 	} | 	} | ||||||
| 	defer res.Body.Close() | 	defer res.Body.Close() | ||||||
| @@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
| 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | ||||||
| 	res2, err := client.Do(req) | 	res2, err := client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysLog(err.Error()) | 		logger.SysLog(err.Error()) | ||||||
| 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") | ||||||
| 	} | 	} | ||||||
| 	defer res2.Body.Close() | 	defer res2.Body.Close() | ||||||
| @@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !common.GitHubOAuthEnabled { | 	if !config.GitHubOAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "管理员未开启通过 GitHub 登录以及注册", | 			"message": "管理员未开启通过 GitHub 登录以及注册", | ||||||
| @@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		if common.RegisterEnabled { | 		if config.RegisterEnabled { | ||||||
| 			user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) | 			user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||||
| 			if githubUser.Name != "" { | 			if githubUser.Name != "" { | ||||||
| 				user.DisplayName = githubUser.Name | 				user.DisplayName = githubUser.Name | ||||||
| @@ -130,8 +133,8 @@ func GitHubOAuth(c *gin.Context) { | |||||||
| 				user.DisplayName = "GitHub User" | 				user.DisplayName = "GitHub User" | ||||||
| 			} | 			} | ||||||
| 			user.Email = githubUser.Email | 			user.Email = githubUser.Email | ||||||
| 			user.Role = common.RoleCommonUser | 			user.Role = model.RoleCommonUser | ||||||
| 			user.Status = common.UserStatusEnabled | 			user.Status = model.UserStatusEnabled | ||||||
| 
 | 
 | ||||||
| 			if err := user.Insert(0); err != nil { | 			if err := user.Insert(0); err != nil { | ||||||
| 				c.JSON(http.StatusOK, gin.H{ | 				c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -149,18 +152,18 @@ func GitHubOAuth(c *gin.Context) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if user.Status != common.UserStatusEnabled { | 	if user.Status != model.UserStatusEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "用户已被封禁", | 			"message": "用户已被封禁", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	setupLogin(&user, c) | 	controller.SetupLogin(&user, c) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func GitHubBind(c *gin.Context) { | func GitHubBind(c *gin.Context) { | ||||||
| 	if !common.GitHubOAuthEnabled { | 	if !config.GitHubOAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "管理员未开启通过 GitHub 登录以及注册", | 			"message": "管理员未开启通过 GitHub 登录以及注册", | ||||||
| @@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) { | |||||||
| 
 | 
 | ||||||
| func GenerateOAuthCode(c *gin.Context) { | func GenerateOAuthCode(c *gin.Context) { | ||||||
| 	session := sessions.Default(c) | 	session := sessions.Default(c) | ||||||
| 	state := common.GetRandomString(12) | 	state := random.GetRandomString(12) | ||||||
| 	session.Set("oauth_state", state) | 	session.Set("oauth_state", state) | ||||||
| 	err := session.Save() | 	err := session.Save() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
							
								
								
									
										200
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										200
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,200 @@ | |||||||
|  | package auth | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/controller" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type LarkOAuthResponse struct { | ||||||
|  | 	AccessToken string `json:"access_token"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LarkUser struct { | ||||||
|  | 	Name   string `json:"name"` | ||||||
|  | 	OpenID string `json:"open_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||||
|  | 	if code == "" { | ||||||
|  | 		return nil, errors.New("无效的参数") | ||||||
|  | 	} | ||||||
|  | 	values := map[string]string{ | ||||||
|  | 		"client_id":     config.LarkClientId, | ||||||
|  | 		"client_secret": config.LarkClientSecret, | ||||||
|  | 		"code":          code, | ||||||
|  | 		"grant_type":    "authorization_code", | ||||||
|  | 		"redirect_uri":  fmt.Sprintf("%s/oauth/lark", config.ServerAddress), | ||||||
|  | 	} | ||||||
|  | 	jsonData, err := json.Marshal(values) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set("Content-Type", "application/json") | ||||||
|  | 	req.Header.Set("Accept", "application/json") | ||||||
|  | 	client := http.Client{ | ||||||
|  | 		Timeout: 5 * time.Second, | ||||||
|  | 	} | ||||||
|  | 	res, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysLog(err.Error()) | ||||||
|  | 		return nil, errors.New("无法连接至飞书服务器,请稍后重试!") | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  | 	var oAuthResponse LarkOAuthResponse | ||||||
|  | 	err = json.NewDecoder(res.Body).Decode(&oAuthResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | ||||||
|  | 	res2, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysLog(err.Error()) | ||||||
|  | 		return nil, errors.New("无法连接至飞书服务器,请稍后重试!") | ||||||
|  | 	} | ||||||
|  | 	var larkUser LarkUser | ||||||
|  | 	err = json.NewDecoder(res2.Body).Decode(&larkUser) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &larkUser, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func LarkOAuth(c *gin.Context) { | ||||||
|  | 	session := sessions.Default(c) | ||||||
|  | 	state := c.Query("state") | ||||||
|  | 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||||
|  | 		c.JSON(http.StatusForbidden, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "state is empty or not same", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	username := session.Get("username") | ||||||
|  | 	if username != nil { | ||||||
|  | 		LarkBind(c) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	code := c.Query("code") | ||||||
|  | 	larkUser, err := getLarkUserInfoByCode(code) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	user := model.User{ | ||||||
|  | 		LarkId: larkUser.OpenID, | ||||||
|  | 	} | ||||||
|  | 	if model.IsLarkIdAlreadyTaken(user.LarkId) { | ||||||
|  | 		err := user.FillUserByLarkId() | ||||||
|  | 		if err != nil { | ||||||
|  | 			c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 				"success": false, | ||||||
|  | 				"message": err.Error(), | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		if config.RegisterEnabled { | ||||||
|  | 			user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||||
|  | 			if larkUser.Name != "" { | ||||||
|  | 				user.DisplayName = larkUser.Name | ||||||
|  | 			} else { | ||||||
|  | 				user.DisplayName = "Lark User" | ||||||
|  | 			} | ||||||
|  | 			user.Role = model.RoleCommonUser | ||||||
|  | 			user.Status = model.UserStatusEnabled | ||||||
|  |  | ||||||
|  | 			if err := user.Insert(0); err != nil { | ||||||
|  | 				c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 					"success": false, | ||||||
|  | 					"message": err.Error(), | ||||||
|  | 				}) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 				"success": false, | ||||||
|  | 				"message": "管理员关闭了新用户注册", | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if user.Status != model.UserStatusEnabled { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"message": "用户已被封禁", | ||||||
|  | 			"success": false, | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	controller.SetupLogin(&user, c) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func LarkBind(c *gin.Context) { | ||||||
|  | 	code := c.Query("code") | ||||||
|  | 	larkUser, err := getLarkUserInfoByCode(code) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	user := model.User{ | ||||||
|  | 		LarkId: larkUser.OpenID, | ||||||
|  | 	} | ||||||
|  | 	if model.IsLarkIdAlreadyTaken(user.LarkId) { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "该飞书账户已被绑定", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	session := sessions.Default(c) | ||||||
|  | 	id := session.Get("id") | ||||||
|  | 	// id := c.GetInt("id")  // critical bug! | ||||||
|  | 	user.Id = id.(int) | ||||||
|  | 	err = user.FillUserById() | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	user.LarkId = larkUser.OpenID | ||||||
|  | 	err = user.Update(false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "bind", | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
| @@ -1,13 +1,15 @@ | |||||||
| package controller | package auth | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/controller" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @@ -22,11 +24,11 @@ func getWeChatIdByCode(code string) (string, error) { | |||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return "", errors.New("无效的参数") | 		return "", errors.New("无效的参数") | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) | 	req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 	req.Header.Set("Authorization", common.WeChatServerToken) | 	req.Header.Set("Authorization", config.WeChatServerToken) | ||||||
| 	client := http.Client{ | 	client := http.Client{ | ||||||
| 		Timeout: 5 * time.Second, | 		Timeout: 5 * time.Second, | ||||||
| 	} | 	} | ||||||
| @@ -50,7 +52,7 @@ func getWeChatIdByCode(code string) (string, error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func WeChatAuth(c *gin.Context) { | func WeChatAuth(c *gin.Context) { | ||||||
| 	if !common.WeChatAuthEnabled { | 	if !config.WeChatAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员未开启通过微信登录以及注册", | 			"message": "管理员未开启通过微信登录以及注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -79,11 +81,11 @@ func WeChatAuth(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		if common.RegisterEnabled { | 		if config.RegisterEnabled { | ||||||
| 			user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) | 			user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||||
| 			user.DisplayName = "WeChat User" | 			user.DisplayName = "WeChat User" | ||||||
| 			user.Role = common.RoleCommonUser | 			user.Role = model.RoleCommonUser | ||||||
| 			user.Status = common.UserStatusEnabled | 			user.Status = model.UserStatusEnabled | ||||||
| 
 | 
 | ||||||
| 			if err := user.Insert(0); err != nil { | 			if err := user.Insert(0); err != nil { | ||||||
| 				c.JSON(http.StatusOK, gin.H{ | 				c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -101,18 +103,18 @@ func WeChatAuth(c *gin.Context) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if user.Status != common.UserStatusEnabled { | 	if user.Status != model.UserStatusEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "用户已被封禁", | 			"message": "用户已被封禁", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	setupLogin(&user, c) | 	controller.SetupLogin(&user, c) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func WeChatBind(c *gin.Context) { | func WeChatBind(c *gin.Context) { | ||||||
| 	if !common.WeChatAuthEnabled { | 	if !config.WeChatAuthEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员未开启通过微信登录以及注册", | 			"message": "管理员未开启通过微信登录以及注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -135,7 +137,7 @@ func WeChatBind(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt(ctxkey.Id) | ||||||
| 	user := model.User{ | 	user := model.User{ | ||||||
| 		Id: id, | 		Id: id, | ||||||
| 	} | 	} | ||||||
| @@ -2,44 +2,48 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"one-api/common" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"one-api/model" | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetSubscription(c *gin.Context) { | func GetSubscription(c *gin.Context) { | ||||||
| 	var remainQuota int | 	var remainQuota int64 | ||||||
| 	var usedQuota int | 	var usedQuota int64 | ||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
| 	var expiredTime int64 | 	var expiredTime int64 | ||||||
| 	if common.DisplayTokenStatEnabled { | 	if config.DisplayTokenStatEnabled { | ||||||
| 		tokenId := c.GetInt("token_id") | 		tokenId := c.GetInt(ctxkey.TokenId) | ||||||
| 		token, err = model.GetTokenById(tokenId) | 		token, err = model.GetTokenById(tokenId) | ||||||
| 		expiredTime = token.ExpiredTime | 		expiredTime = token.ExpiredTime | ||||||
| 		remainQuota = token.RemainQuota | 		remainQuota = token.RemainQuota | ||||||
| 		usedQuota = token.UsedQuota | 		usedQuota = token.UsedQuota | ||||||
| 	} else { | 	} else { | ||||||
| 		userId := c.GetInt("id") | 		userId := c.GetInt(ctxkey.Id) | ||||||
| 		remainQuota, err = model.GetUserQuota(userId) | 		remainQuota, err = model.GetUserQuota(userId) | ||||||
| 		usedQuota, err = model.GetUserUsedQuota(userId) | 		if err != nil { | ||||||
|  | 			usedQuota, err = model.GetUserUsedQuota(userId) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	if expiredTime <= 0 { | 	if expiredTime <= 0 { | ||||||
| 		expiredTime = 0 | 		expiredTime = 0 | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		openAIError := OpenAIError{ | 		Error := relaymodel.Error{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "upstream_error", | 			Type:    "upstream_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": Error, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	quota := remainQuota + usedQuota | 	quota := remainQuota + usedQuota | ||||||
| 	amount := float64(quota) | 	amount := float64(quota) | ||||||
| 	if common.DisplayInCurrencyEnabled { | 	if config.DisplayInCurrencyEnabled { | ||||||
| 		amount /= common.QuotaPerUnit | 		amount /= config.QuotaPerUnit | ||||||
| 	} | 	} | ||||||
| 	if token != nil && token.UnlimitedQuota { | 	if token != nil && token.UnlimitedQuota { | ||||||
| 		amount = 100000000 | 		amount = 100000000 | ||||||
| @@ -57,30 +61,30 @@ func GetSubscription(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetUsage(c *gin.Context) { | func GetUsage(c *gin.Context) { | ||||||
| 	var quota int | 	var quota int64 | ||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
| 	if common.DisplayTokenStatEnabled { | 	if config.DisplayTokenStatEnabled { | ||||||
| 		tokenId := c.GetInt("token_id") | 		tokenId := c.GetInt(ctxkey.TokenId) | ||||||
| 		token, err = model.GetTokenById(tokenId) | 		token, err = model.GetTokenById(tokenId) | ||||||
| 		quota = token.UsedQuota | 		quota = token.UsedQuota | ||||||
| 	} else { | 	} else { | ||||||
| 		userId := c.GetInt("id") | 		userId := c.GetInt(ctxkey.Id) | ||||||
| 		quota, err = model.GetUserUsedQuota(userId) | 		quota, err = model.GetUserUsedQuota(userId) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		openAIError := OpenAIError{ | 		Error := relaymodel.Error{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "one_api_error", | 			Type:    "one_api_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": Error, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	amount := float64(quota) | 	amount := float64(quota) | ||||||
| 	if common.DisplayInCurrencyEnabled { | 	if config.DisplayInCurrencyEnabled { | ||||||
| 		amount /= common.QuotaPerUnit | 		amount /= config.QuotaPerUnit | ||||||
| 	} | 	} | ||||||
| 	usage := OpenAIUsageResponse{ | 	usage := OpenAIUsageResponse{ | ||||||
| 		Object:     "list", | 		Object:     "list", | ||||||
|   | |||||||
| @@ -4,10 +4,14 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/client" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -92,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | |||||||
| 	for k := range headers { | 	for k := range headers { | ||||||
| 		req.Header.Add(k, headers.Get(k)) | 		req.Header.Add(k, headers.Get(k)) | ||||||
| 	} | 	} | ||||||
| 	res, err := httpClient.Do(req) | 	res, err := client.HTTPClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -200,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | 	baseURL := channeltype.ChannelBaseURLs[channel.Type] | ||||||
| 	if channel.GetBaseURL() == "" { | 	if channel.GetBaseURL() == "" { | ||||||
| 		channel.BaseURL = &baseURL | 		channel.BaseURL = &baseURL | ||||||
| 	} | 	} | ||||||
| 	switch channel.Type { | 	switch channel.Type { | ||||||
| 	case common.ChannelTypeOpenAI: | 	case channeltype.OpenAI: | ||||||
| 		if channel.GetBaseURL() != "" { | 		if channel.GetBaseURL() != "" { | ||||||
| 			baseURL = channel.GetBaseURL() | 			baseURL = channel.GetBaseURL() | ||||||
| 		} | 		} | ||||||
| 	case common.ChannelTypeAzure: | 	case channeltype.Azure: | ||||||
| 		return 0, errors.New("尚未实现") | 		return 0, errors.New("尚未实现") | ||||||
| 	case common.ChannelTypeCustom: | 	case channeltype.Custom: | ||||||
| 		baseURL = channel.GetBaseURL() | 		baseURL = channel.GetBaseURL() | ||||||
| 	case common.ChannelTypeCloseAI: | 	case channeltype.CloseAI: | ||||||
| 		return updateChannelCloseAIBalance(channel) | 		return updateChannelCloseAIBalance(channel) | ||||||
| 	case common.ChannelTypeOpenAISB: | 	case channeltype.OpenAISB: | ||||||
| 		return updateChannelOpenAISBBalance(channel) | 		return updateChannelOpenAISBBalance(channel) | ||||||
| 	case common.ChannelTypeAIProxy: | 	case channeltype.AIProxy: | ||||||
| 		return updateChannelAIProxyBalance(channel) | 		return updateChannelAIProxyBalance(channel) | ||||||
| 	case common.ChannelTypeAPI2GPT: | 	case channeltype.API2GPT: | ||||||
| 		return updateChannelAPI2GPTBalance(channel) | 		return updateChannelAPI2GPTBalance(channel) | ||||||
| 	case common.ChannelTypeAIGC2D: | 	case channeltype.AIGC2D: | ||||||
| 		return updateChannelAIGC2DBalance(channel) | 		return updateChannelAIGC2DBalance(channel) | ||||||
| 	default: | 	default: | ||||||
| 		return 0, errors.New("尚未实现") | 		return 0, errors.New("尚未实现") | ||||||
| @@ -292,16 +296,16 @@ func UpdateChannelBalance(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func updateAllChannelsBalance() error { | func updateAllChannelsBalance() error { | ||||||
| 	channels, err := model.GetAllChannels(0, 0, true) | 	channels, err := model.GetAllChannels(0, 0, "all") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	for _, channel := range channels { | 	for _, channel := range channels { | ||||||
| 		if channel.Status != common.ChannelStatusEnabled { | 		if channel.Status != model.ChannelStatusEnabled { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		// TODO: support Azure | 		// TODO: support Azure | ||||||
| 		if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { | 		if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		balance, err := updateChannelBalance(channel) | 		balance, err := updateChannelBalance(channel) | ||||||
| @@ -310,24 +314,23 @@ func updateAllChannelsBalance() error { | |||||||
| 		} else { | 		} else { | ||||||
| 			// err is nil & balance <= 0 means quota is used up | 			// err is nil & balance <= 0 means quota is used up | ||||||
| 			if balance <= 0 { | 			if balance <= 0 { | ||||||
| 				disableChannel(channel.Id, channel.Name, "余额不足") | 				monitor.DisableChannel(channel.Id, channel.Name, "余额不足") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		time.Sleep(common.RequestInterval) | 		time.Sleep(config.RequestInterval) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateAllChannelsBalance(c *gin.Context) { | func UpdateAllChannelsBalance(c *gin.Context) { | ||||||
| 	// TODO: make it async | 	//err := updateAllChannelsBalance() | ||||||
| 	err := updateAllChannelsBalance() | 	//if err != nil { | ||||||
| 	if err != nil { | 	//	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 	//		"success": false, | ||||||
| 			"success": false, | 	//		"message": err.Error(), | ||||||
| 			"message": err.Error(), | 	//	}) | ||||||
| 		}) | 	//	return | ||||||
| 		return | 	//} | ||||||
| 	} |  | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| @@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { | |||||||
| func AutomaticallyUpdateChannels(frequency int) { | func AutomaticallyUpdateChannels(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||||
| 		common.SysLog("updating all channels") | 		logger.SysLog("updating all channels") | ||||||
| 		_ = updateAllChannelsBalance() | 		_ = updateAllChannelsBalance() | ||||||
| 		common.SysLog("channels update done") | 		logger.SysLog("channels update done") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -7,96 +7,37 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"net/http/httptest" | ||||||
| 	"one-api/model" | 	"net/url" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
|  | 	"github.com/songquanpeng/one-api/middleware" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
|  | 	relay "github.com/songquanpeng/one-api/relay" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/controller" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/meta" | ||||||
|  | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | ||||||
| 	switch channel.Type { | 	testRequest := &relaymodel.GeneralOpenAIRequest{ | ||||||
| 	case common.ChannelTypePaLM: | 		MaxTokens: 2, | ||||||
| 		fallthrough | 		Stream:    false, | ||||||
| 	case common.ChannelTypeGemini: | 		Model:     "gpt-3.5-turbo", | ||||||
| 		fallthrough |  | ||||||
| 	case common.ChannelTypeAnthropic: |  | ||||||
| 		fallthrough |  | ||||||
| 	case common.ChannelTypeBaidu: |  | ||||||
| 		fallthrough |  | ||||||
| 	case common.ChannelTypeZhipu: |  | ||||||
| 		fallthrough |  | ||||||
| 	case common.ChannelTypeAli: |  | ||||||
| 		fallthrough |  | ||||||
| 	case common.ChannelType360: |  | ||||||
| 		fallthrough |  | ||||||
| 	case common.ChannelTypeXunfei: |  | ||||||
| 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil |  | ||||||
| 	case common.ChannelTypeAzure: |  | ||||||
| 		request.Model = "gpt-35-turbo" |  | ||||||
| 		defer func() { |  | ||||||
| 			if err != nil { |  | ||||||
| 				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 	default: |  | ||||||
| 		request.Model = "gpt-3.5-turbo" |  | ||||||
| 	} | 	} | ||||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | 	testMessage := relaymodel.Message{ | ||||||
| 	if channel.Type == common.ChannelTypeAzure { |  | ||||||
| 		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) |  | ||||||
| 	} else { |  | ||||||
| 		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { |  | ||||||
| 			requestURL = baseURL |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) |  | ||||||
| 	} |  | ||||||
| 	jsonData, err := json.Marshal(request) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err, nil |  | ||||||
| 	} |  | ||||||
| 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err, nil |  | ||||||
| 	} |  | ||||||
| 	if channel.Type == common.ChannelTypeAzure { |  | ||||||
| 		req.Header.Set("api-key", channel.Key) |  | ||||||
| 	} else { |  | ||||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Content-Type", "application/json") |  | ||||||
| 	resp, err := httpClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err, nil |  | ||||||
| 	} |  | ||||||
| 	defer resp.Body.Close() |  | ||||||
| 	var response TextResponse |  | ||||||
| 	body, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err, nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(body, &response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil |  | ||||||
| 	} |  | ||||||
| 	if response.Usage.CompletionTokens == 0 { |  | ||||||
| 		if response.Error.Message == "" { |  | ||||||
| 			response.Error.Message = "补全 tokens 非预期返回 0" |  | ||||||
| 		} |  | ||||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error |  | ||||||
| 	} |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func buildTestRequest() *ChatRequest { |  | ||||||
| 	testRequest := &ChatRequest{ |  | ||||||
| 		Model:     "", // this will be set later |  | ||||||
| 		MaxTokens: 1, |  | ||||||
| 	} |  | ||||||
| 	testMessage := Message{ |  | ||||||
| 		Role:    "user", | 		Role:    "user", | ||||||
| 		Content: "hi", | 		Content: "hi", | ||||||
| 	} | 	} | ||||||
| @@ -104,6 +45,83 @@ func buildTestRequest() *ChatRequest { | |||||||
| 	return testRequest | 	return testRequest | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { | ||||||
|  | 	w := httptest.NewRecorder() | ||||||
|  | 	c, _ := gin.CreateTestContext(w) | ||||||
|  | 	c.Request = &http.Request{ | ||||||
|  | 		Method: "POST", | ||||||
|  | 		URL:    &url.URL{Path: "/v1/chat/completions"}, | ||||||
|  | 		Body:   nil, | ||||||
|  | 		Header: make(http.Header), | ||||||
|  | 	} | ||||||
|  | 	c.Request.Header.Set("Authorization", "Bearer "+channel.Key) | ||||||
|  | 	c.Request.Header.Set("Content-Type", "application/json") | ||||||
|  | 	c.Set(ctxkey.Channel, channel.Type) | ||||||
|  | 	c.Set(ctxkey.BaseURL, channel.GetBaseURL()) | ||||||
|  | 	cfg, _ := channel.LoadConfig() | ||||||
|  | 	c.Set(ctxkey.Config, cfg) | ||||||
|  | 	middleware.SetupContextForSelectedChannel(c, channel, "") | ||||||
|  | 	meta := meta.GetByContext(c) | ||||||
|  | 	apiType := channeltype.ToAPIType(channel.Type) | ||||||
|  | 	adaptor := relay.GetAdaptor(apiType) | ||||||
|  | 	if adaptor == nil { | ||||||
|  | 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||||
|  | 	} | ||||||
|  | 	adaptor.Init(meta) | ||||||
|  | 	var modelName string | ||||||
|  | 	modelList := adaptor.GetModelList() | ||||||
|  | 	modelMap := channel.GetModelMapping() | ||||||
|  | 	if len(modelList) != 0 { | ||||||
|  | 		modelName = modelList[0] | ||||||
|  | 	} | ||||||
|  | 	if modelName == "" || !strings.Contains(channel.Models, modelName) { | ||||||
|  | 		modelNames := strings.Split(channel.Models, ",") | ||||||
|  | 		if len(modelNames) > 0 { | ||||||
|  | 			modelName = modelNames[0] | ||||||
|  | 		} | ||||||
|  | 		if modelMap != nil && modelMap[modelName] != "" { | ||||||
|  | 			modelName = modelMap[modelName] | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	request := buildTestRequest() | ||||||
|  | 	request.Model = modelName | ||||||
|  | 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | ||||||
|  | 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err, nil | ||||||
|  | 	} | ||||||
|  | 	jsonData, err := json.Marshal(convertedRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err, nil | ||||||
|  | 	} | ||||||
|  | 	logger.SysLog(string(jsonData)) | ||||||
|  | 	requestBody := bytes.NewBuffer(jsonData) | ||||||
|  | 	c.Request.Body = io.NopCloser(requestBody) | ||||||
|  | 	resp, err := adaptor.DoRequest(c, meta, requestBody) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err, nil | ||||||
|  | 	} | ||||||
|  | 	if resp != nil && resp.StatusCode != http.StatusOK { | ||||||
|  | 		err := controller.RelayErrorHandler(resp) | ||||||
|  | 		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error | ||||||
|  | 	} | ||||||
|  | 	usage, respErr := adaptor.DoResponse(c, resp, meta) | ||||||
|  | 	if respErr != nil { | ||||||
|  | 		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error | ||||||
|  | 	} | ||||||
|  | 	if usage == nil { | ||||||
|  | 		return errors.New("usage is nil"), nil | ||||||
|  | 	} | ||||||
|  | 	result := w.Result() | ||||||
|  | 	// print result.Body | ||||||
|  | 	respBody, err := io.ReadAll(result.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err, nil | ||||||
|  | 	} | ||||||
|  | 	logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestChannel(c *gin.Context) { | func TestChannel(c *gin.Context) { | ||||||
| 	id, err := strconv.Atoi(c.Param("id")) | 	id, err := strconv.Atoi(c.Param("id")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -121,9 +139,8 @@ func TestChannel(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	testRequest := buildTestRequest() |  | ||||||
| 	tik := time.Now() | 	tik := time.Now() | ||||||
| 	err, _ = testChannel(channel, *testRequest) | 	err, _ = testChannel(channel) | ||||||
| 	tok := time.Now() | 	tok := time.Now() | ||||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | 	milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 	go channel.UpdateResponseTime(milliseconds) | 	go channel.UpdateResponseTime(milliseconds) | ||||||
| @@ -147,35 +164,9 @@ func TestChannel(c *gin.Context) { | |||||||
| var testAllChannelsLock sync.Mutex | var testAllChannelsLock sync.Mutex | ||||||
| var testAllChannelsRunning bool = false | var testAllChannelsRunning bool = false | ||||||
|  |  | ||||||
| func notifyRootUser(subject string, content string) { | func testChannels(notify bool, scope string) error { | ||||||
| 	if common.RootUserEmail == "" { | 	if config.RootUserEmail == "" { | ||||||
| 		common.RootUserEmail = model.GetRootUserEmail() | 		config.RootUserEmail = model.GetRootUserEmail() | ||||||
| 	} |  | ||||||
| 	err := common.SendEmail(subject, common.RootUserEmail, content) |  | ||||||
| 	if err != nil { |  | ||||||
| 		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // disable & notify |  | ||||||
| func disableChannel(channelId int, channelName string, reason string) { |  | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) |  | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) |  | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) |  | ||||||
| 	notifyRootUser(subject, content) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // enable & notify |  | ||||||
| func enableChannel(channelId int, channelName string) { |  | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) |  | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) |  | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) |  | ||||||
| 	notifyRootUser(subject, content) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func testAllChannels(notify bool) error { |  | ||||||
| 	if common.RootUserEmail == "" { |  | ||||||
| 		common.RootUserEmail = model.GetRootUserEmail() |  | ||||||
| 	} | 	} | ||||||
| 	testAllChannelsLock.Lock() | 	testAllChannelsLock.Lock() | ||||||
| 	if testAllChannelsRunning { | 	if testAllChannelsRunning { | ||||||
| @@ -184,50 +175,57 @@ func testAllChannels(notify bool) error { | |||||||
| 	} | 	} | ||||||
| 	testAllChannelsRunning = true | 	testAllChannelsRunning = true | ||||||
| 	testAllChannelsLock.Unlock() | 	testAllChannelsLock.Unlock() | ||||||
| 	channels, err := model.GetAllChannels(0, 0, true) | 	channels, err := model.GetAllChannels(0, 0, scope) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	testRequest := buildTestRequest() | 	var disableThreshold = int64(config.ChannelDisableThreshold * 1000) | ||||||
| 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000) |  | ||||||
| 	if disableThreshold == 0 { | 	if disableThreshold == 0 { | ||||||
| 		disableThreshold = 10000000 // a impossible value | 		disableThreshold = 10000000 // a impossible value | ||||||
| 	} | 	} | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for _, channel := range channels { | 		for _, channel := range channels { | ||||||
| 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled | 			isChannelEnabled := channel.Status == model.ChannelStatusEnabled | ||||||
| 			tik := time.Now() | 			tik := time.Now() | ||||||
| 			err, openaiErr := testChannel(channel, *testRequest) | 			err, openaiErr := testChannel(channel) | ||||||
| 			tok := time.Now() | 			tok := time.Now() | ||||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | 			milliseconds := tok.Sub(tik).Milliseconds() | ||||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | 			if isChannelEnabled && milliseconds > disableThreshold { | ||||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				if config.AutomaticDisableChannelEnabled { | ||||||
|  | 					monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||||
|  | 				} else { | ||||||
|  | 					_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { | 			if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) { | ||||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | 				monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||||
| 			} | 			} | ||||||
| 			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { | 			if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) { | ||||||
| 				enableChannel(channel.Id, channel.Name) | 				monitor.EnableChannel(channel.Id, channel.Name) | ||||||
| 			} | 			} | ||||||
| 			channel.UpdateResponseTime(milliseconds) | 			channel.UpdateResponseTime(milliseconds) | ||||||
| 			time.Sleep(common.RequestInterval) | 			time.Sleep(config.RequestInterval) | ||||||
| 		} | 		} | ||||||
| 		testAllChannelsLock.Lock() | 		testAllChannelsLock.Lock() | ||||||
| 		testAllChannelsRunning = false | 		testAllChannelsRunning = false | ||||||
| 		testAllChannelsLock.Unlock() | 		testAllChannelsLock.Unlock() | ||||||
| 		if notify { | 		if notify { | ||||||
| 			err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | 			err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常") | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | 				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestAllChannels(c *gin.Context) { | func TestChannels(c *gin.Context) { | ||||||
| 	err := testAllChannels(true) | 	scope := c.Query("scope") | ||||||
|  | 	if scope == "" { | ||||||
|  | 		scope = "all" | ||||||
|  | 	} | ||||||
|  | 	err := testChannels(true, scope) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -245,8 +243,8 @@ func TestAllChannels(c *gin.Context) { | |||||||
| func AutomaticallyTestChannels(frequency int) { | func AutomaticallyTestChannels(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||||
| 		common.SysLog("testing all channels") | 		logger.SysLog("testing all channels") | ||||||
| 		_ = testAllChannels(false) | 		_ = testChannels(false, "all") | ||||||
| 		common.SysLog("channel test finished") | 		logger.SysLog("channel test finished") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,9 +2,10 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| @@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) | 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	channel.CreatedTime = common.GetTimestamp() | 	channel.CreatedTime = helper.GetTimestamp() | ||||||
| 	keys := strings.Split(channel.Key, "\n") | 	keys := strings.Split(channel.Key, "\n") | ||||||
| 	channels := make([]model.Channel, 0, len(keys)) | 	channels := make([]model.Channel, 0, len(keys)) | ||||||
| 	for _, key := range keys { | 	for _, key := range keys { | ||||||
|   | |||||||
| @@ -2,13 +2,13 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetGroups(c *gin.Context) { | func GetGroups(c *gin.Context) { | ||||||
| 	groupNames := make([]string, 0) | 	groupNames := make([]string, 0) | ||||||
| 	for groupName, _ := range common.GroupRatio { | 	for groupName := range billingratio.GroupRatio { | ||||||
| 		groupNames = append(groupNames, groupName) | 		groupNames = append(groupNames, groupName) | ||||||
| 	} | 	} | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|   | |||||||
| @@ -2,9 +2,10 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -20,7 +21,7 @@ func GetAllLogs(c *gin.Context) { | |||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	channel, _ := strconv.Atoi(c.Query("channel")) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
| 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) | 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -41,13 +42,13 @@ func GetUserLogs(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | 	logType, _ := strconv.Atoi(c.Query("type")) | ||||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -83,7 +84,7 @@ func SearchAllLogs(c *gin.Context) { | |||||||
|  |  | ||||||
| func SearchUserLogs(c *gin.Context) { | func SearchUserLogs(c *gin.Context) { | ||||||
| 	keyword := c.Query("keyword") | 	keyword := c.Query("keyword") | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	logs, err := model.SearchUserLogs(userId, keyword) | 	logs, err := model.SearchUserLogs(userId, keyword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -122,7 +123,7 @@ func GetLogsStat(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetLogsSelfStat(c *gin.Context) { | func GetLogsSelfStat(c *gin.Context) { | ||||||
| 	username := c.GetString("username") | 	username := c.GetString(ctxkey.Username) | ||||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | 	logType, _ := strconv.Atoi(c.Query("type")) | ||||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||||
|   | |||||||
| @@ -3,9 +3,11 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -18,55 +20,56 @@ func GetStatus(c *gin.Context) { | |||||||
| 		"data": gin.H{ | 		"data": gin.H{ | ||||||
| 			"version":             common.Version, | 			"version":             common.Version, | ||||||
| 			"start_time":          common.StartTime, | 			"start_time":          common.StartTime, | ||||||
| 			"email_verification":  common.EmailVerificationEnabled, | 			"email_verification":  config.EmailVerificationEnabled, | ||||||
| 			"github_oauth":        common.GitHubOAuthEnabled, | 			"github_oauth":        config.GitHubOAuthEnabled, | ||||||
| 			"github_client_id":    common.GitHubClientId, | 			"github_client_id":    config.GitHubClientId, | ||||||
| 			"system_name":         common.SystemName, | 			"lark_client_id":      config.LarkClientId, | ||||||
| 			"logo":                common.Logo, | 			"system_name":         config.SystemName, | ||||||
| 			"footer_html":         common.Footer, | 			"logo":                config.Logo, | ||||||
| 			"wechat_qrcode":       common.WeChatAccountQRCodeImageURL, | 			"footer_html":         config.Footer, | ||||||
| 			"wechat_login":        common.WeChatAuthEnabled, | 			"wechat_qrcode":       config.WeChatAccountQRCodeImageURL, | ||||||
| 			"server_address":      common.ServerAddress, | 			"wechat_login":        config.WeChatAuthEnabled, | ||||||
| 			"turnstile_check":     common.TurnstileCheckEnabled, | 			"server_address":      config.ServerAddress, | ||||||
| 			"turnstile_site_key":  common.TurnstileSiteKey, | 			"turnstile_check":     config.TurnstileCheckEnabled, | ||||||
| 			"top_up_link":         common.TopUpLink, | 			"turnstile_site_key":  config.TurnstileSiteKey, | ||||||
| 			"chat_link":           common.ChatLink, | 			"top_up_link":         config.TopUpLink, | ||||||
| 			"quota_per_unit":      common.QuotaPerUnit, | 			"chat_link":           config.ChatLink, | ||||||
| 			"display_in_currency": common.DisplayInCurrencyEnabled, | 			"quota_per_unit":      config.QuotaPerUnit, | ||||||
|  | 			"display_in_currency": config.DisplayInCurrencyEnabled, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetNotice(c *gin.Context) { | func GetNotice(c *gin.Context) { | ||||||
| 	common.OptionMapRWMutex.RLock() | 	config.OptionMapRWMutex.RLock() | ||||||
| 	defer common.OptionMapRWMutex.RUnlock() | 	defer config.OptionMapRWMutex.RUnlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    common.OptionMap["Notice"], | 		"data":    config.OptionMap["Notice"], | ||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAbout(c *gin.Context) { | func GetAbout(c *gin.Context) { | ||||||
| 	common.OptionMapRWMutex.RLock() | 	config.OptionMapRWMutex.RLock() | ||||||
| 	defer common.OptionMapRWMutex.RUnlock() | 	defer config.OptionMapRWMutex.RUnlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    common.OptionMap["About"], | 		"data":    config.OptionMap["About"], | ||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetHomePageContent(c *gin.Context) { | func GetHomePageContent(c *gin.Context) { | ||||||
| 	common.OptionMapRWMutex.RLock() | 	config.OptionMapRWMutex.RLock() | ||||||
| 	defer common.OptionMapRWMutex.RUnlock() | 	defer config.OptionMapRWMutex.RUnlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    common.OptionMap["HomePageContent"], | 		"data":    config.OptionMap["HomePageContent"], | ||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| @@ -80,9 +83,9 @@ func SendEmailVerification(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if common.EmailDomainRestrictionEnabled { | 	if config.EmailDomainRestrictionEnabled { | ||||||
| 		allowed := false | 		allowed := false | ||||||
| 		for _, domain := range common.EmailDomainWhitelist { | 		for _, domain := range config.EmailDomainWhitelist { | ||||||
| 			if strings.HasSuffix(email, "@"+domain) { | 			if strings.HasSuffix(email, "@"+domain) { | ||||||
| 				allowed = true | 				allowed = true | ||||||
| 				break | 				break | ||||||
| @@ -105,11 +108,11 @@ func SendEmailVerification(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	code := common.GenerateVerificationCode(6) | 	code := common.GenerateVerificationCode(6) | ||||||
| 	common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) | 	common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) | ||||||
| 	subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) | 	subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) | ||||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | ||||||
| 		"<p>您的验证码为: <strong>%s</strong></p>"+ | 		"<p>您的验证码为: <strong>%s</strong></p>"+ | ||||||
| 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes) | 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := message.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -142,13 +145,13 @@ func SendPasswordResetEmail(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	code := common.GenerateVerificationCode(0) | 	code := common.GenerateVerificationCode(0) | ||||||
| 	common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) | 	common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) | ||||||
| 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) | 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) | ||||||
| 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | 	subject := fmt.Sprintf("%s密码重置", config.SystemName) | ||||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | ||||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) | 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := message.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
|   | |||||||
| @@ -2,8 +2,17 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	relay "github.com/songquanpeng/one-api/relay" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/apitype" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/meta" | ||||||
|  | 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/models/list | // https://platform.openai.com/docs/api-reference/models/list | ||||||
| @@ -33,8 +42,9 @@ type OpenAIModels struct { | |||||||
| 	Parent     *string                 `json:"parent"` | 	Parent     *string                 `json:"parent"` | ||||||
| } | } | ||||||
|  |  | ||||||
| var openAIModels []OpenAIModels | var models []OpenAIModels | ||||||
| var openAIModelsMap map[string]OpenAIModels | var modelsMap map[string]OpenAIModels | ||||||
|  | var channelId2Models map[int][]string | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	var permission []OpenAIModelPermission | 	var permission []OpenAIModelPermission | ||||||
| @@ -53,574 +63,151 @@ func init() { | |||||||
| 		IsBlocking:         false, | 		IsBlocking:         false, | ||||||
| 	}) | 	}) | ||||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||||
| 	openAIModels = []OpenAIModels{ | 	for i := 0; i < apitype.Dummy; i++ { | ||||||
| 		{ | 		if i == apitype.AIProxyLibrary { | ||||||
| 			Id:         "dall-e-2", | 			continue | ||||||
| 			Object:     "model", | 		} | ||||||
| 			Created:    1677649963, | 		adaptor := relay.GetAdaptor(i) | ||||||
| 			OwnedBy:    "openai", | 		channelName := adaptor.GetChannelName() | ||||||
| 			Permission: permission, | 		modelNames := adaptor.GetModelList() | ||||||
| 			Root:       "dall-e-2", | 		for _, modelName := range modelNames { | ||||||
| 			Parent:     nil, | 			models = append(models, OpenAIModels{ | ||||||
| 		}, | 				Id:         modelName, | ||||||
| 		{ | 				Object:     "model", | ||||||
| 			Id:         "dall-e-3", | 				Created:    1626777600, | ||||||
| 			Object:     "model", | 				OwnedBy:    channelName, | ||||||
| 			Created:    1677649963, | 				Permission: permission, | ||||||
| 			OwnedBy:    "openai", | 				Root:       modelName, | ||||||
| 			Permission: permission, | 				Parent:     nil, | ||||||
| 			Root:       "dall-e-3", | 			}) | ||||||
| 			Parent:     nil, | 		} | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "whisper-1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "whisper-1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "tts-1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "tts-1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "tts-1-1106", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "tts-1-1106", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "tts-1-hd", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "tts-1-hd", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "tts-1-hd-1106", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "tts-1-hd-1106", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-3.5-turbo", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-3.5-turbo", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-3.5-turbo-0301", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-3.5-turbo-0301", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-3.5-turbo-0613", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-3.5-turbo-0613", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-3.5-turbo-16k", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-3.5-turbo-16k", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-3.5-turbo-16k-0613", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-3.5-turbo-16k-0613", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-3.5-turbo-1106", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1699593571, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-3.5-turbo-1106", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-3.5-turbo-instruct", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-3.5-turbo-instruct", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4-0314", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4-0314", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4-0613", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4-0613", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4-32k", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4-32k", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4-32k-0314", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4-32k-0314", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4-32k-0613", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4-32k-0613", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4-1106-preview", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1699593571, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4-1106-preview", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gpt-4-vision-preview", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1699593571, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gpt-4-vision-preview", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-embedding-ada-002", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-embedding-ada-002", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-davinci-003", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-davinci-003", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-davinci-002", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-davinci-002", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-curie-001", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-curie-001", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-babbage-001", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-babbage-001", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-ada-001", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-ada-001", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-moderation-latest", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-moderation-latest", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-moderation-stable", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-moderation-stable", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-davinci-edit-001", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-davinci-edit-001", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "code-davinci-edit-001", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "code-davinci-edit-001", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "davinci-002", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "davinci-002", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "babbage-002", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "openai", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "babbage-002", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "claude-instant-1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "anthropic", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "claude-instant-1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "claude-2", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "anthropic", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "claude-2", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "claude-2.1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "anthropic", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "claude-2.1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "claude-2.0", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "anthropic", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "claude-2.0", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "ERNIE-Bot", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "baidu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "ERNIE-Bot", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "ERNIE-Bot-turbo", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "baidu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "ERNIE-Bot-turbo", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "ERNIE-Bot-4", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "baidu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "ERNIE-Bot-4", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "Embedding-V1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "baidu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "Embedding-V1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "PaLM-2", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "google palm", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "PaLM-2", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gemini-pro", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "google gemini", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gemini-pro", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "gemini-pro-vision", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "google gemini", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "gemini-pro-vision", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "chatglm_turbo", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "zhipu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "chatglm_turbo", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "chatglm_pro", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "zhipu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "chatglm_pro", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "chatglm_std", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "zhipu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "chatglm_std", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "chatglm_lite", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "zhipu", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "chatglm_lite", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "qwen-turbo", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "ali", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "qwen-turbo", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "qwen-plus", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "ali", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "qwen-plus", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "qwen-max", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "ali", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "qwen-max", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "qwen-max-longcontext", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "ali", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "qwen-max-longcontext", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "text-embedding-v1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "ali", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "text-embedding-v1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "SparkDesk", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "xunfei", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "SparkDesk", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "360GPT_S2_V9", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "360", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "360GPT_S2_V9", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "embedding-bert-512-v1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "360", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "embedding-bert-512-v1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "embedding_s1_v1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "360", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "embedding_s1_v1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "semantic_similarity_s1_v1", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "360", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "semantic_similarity_s1_v1", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Id:         "hunyuan", |  | ||||||
| 			Object:     "model", |  | ||||||
| 			Created:    1677649963, |  | ||||||
| 			OwnedBy:    "tencent", |  | ||||||
| 			Permission: permission, |  | ||||||
| 			Root:       "hunyuan", |  | ||||||
| 			Parent:     nil, |  | ||||||
| 		}, |  | ||||||
| 	} | 	} | ||||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | 	for _, channelType := range openai.CompatibleChannels { | ||||||
| 	for _, model := range openAIModels { | 		if channelType == channeltype.Azure { | ||||||
| 		openAIModelsMap[model.Id] = model | 			continue | ||||||
|  | 		} | ||||||
|  | 		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) | ||||||
|  | 		for _, modelName := range channelModelList { | ||||||
|  | 			models = append(models, OpenAIModels{ | ||||||
|  | 				Id:         modelName, | ||||||
|  | 				Object:     "model", | ||||||
|  | 				Created:    1626777600, | ||||||
|  | 				OwnedBy:    channelName, | ||||||
|  | 				Permission: permission, | ||||||
|  | 				Root:       modelName, | ||||||
|  | 				Parent:     nil, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	modelsMap = make(map[string]OpenAIModels) | ||||||
|  | 	for _, model := range models { | ||||||
|  | 		modelsMap[model.Id] = model | ||||||
|  | 	} | ||||||
|  | 	channelId2Models = make(map[int][]string) | ||||||
|  | 	for i := 1; i < channeltype.Dummy; i++ { | ||||||
|  | 		adaptor := relay.GetAdaptor(channeltype.ToAPIType(i)) | ||||||
|  | 		meta := &meta.Meta{ | ||||||
|  | 			ChannelType: i, | ||||||
|  | 		} | ||||||
|  | 		adaptor.Init(meta) | ||||||
|  | 		channelId2Models[i] = adaptor.GetModelList() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func ListModels(c *gin.Context) { | func DashboardListModels(c *gin.Context) { | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    channelId2Models, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ListAllModels(c *gin.Context) { | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(200, gin.H{ | ||||||
| 		"object": "list", | 		"object": "list", | ||||||
| 		"data":   openAIModels, | 		"data":   models, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ListModels(c *gin.Context) { | ||||||
|  | 	ctx := c.Request.Context() | ||||||
|  | 	var availableModels []string | ||||||
|  | 	if c.GetString(ctxkey.AvailableModels) != "" { | ||||||
|  | 		availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",") | ||||||
|  | 	} else { | ||||||
|  | 		userId := c.GetInt(ctxkey.Id) | ||||||
|  | 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||||
|  | 		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) | ||||||
|  | 	} | ||||||
|  | 	modelSet := make(map[string]bool) | ||||||
|  | 	for _, availableModel := range availableModels { | ||||||
|  | 		modelSet[availableModel] = true | ||||||
|  | 	} | ||||||
|  | 	availableOpenAIModels := make([]OpenAIModels, 0) | ||||||
|  | 	for _, model := range models { | ||||||
|  | 		if _, ok := modelSet[model.Id]; ok { | ||||||
|  | 			modelSet[model.Id] = false | ||||||
|  | 			availableOpenAIModels = append(availableOpenAIModels, model) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	for modelName, ok := range modelSet { | ||||||
|  | 		if ok { | ||||||
|  | 			availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ | ||||||
|  | 				Id:      modelName, | ||||||
|  | 				Object:  "model", | ||||||
|  | 				Created: 1626777600, | ||||||
|  | 				OwnedBy: "custom", | ||||||
|  | 				Root:    modelName, | ||||||
|  | 				Parent:  nil, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	c.JSON(200, gin.H{ | ||||||
|  | 		"object": "list", | ||||||
|  | 		"data":   availableOpenAIModels, | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func RetrieveModel(c *gin.Context) { | func RetrieveModel(c *gin.Context) { | ||||||
| 	modelId := c.Param("model") | 	modelId := c.Param("model") | ||||||
| 	if model, ok := openAIModelsMap[modelId]; ok { | 	if model, ok := modelsMap[modelId]; ok { | ||||||
| 		c.JSON(200, model) | 		c.JSON(200, model) | ||||||
| 	} else { | 	} else { | ||||||
| 		openAIError := OpenAIError{ | 		Error := relaymodel.Error{ | ||||||
| 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | 			Message: fmt.Sprintf("The model '%s' does not exist", modelId), | ||||||
| 			Type:    "invalid_request_error", | 			Type:    "invalid_request_error", | ||||||
| 			Param:   "model", | 			Param:   "model", | ||||||
| 			Code:    "model_not_found", | 			Code:    "model_not_found", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": Error, | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetUserAvailableModels(c *gin.Context) { | ||||||
|  | 	ctx := c.Request.Context() | ||||||
|  | 	id := c.GetInt(ctxkey.Id) | ||||||
|  | 	userGroup, err := model.CacheGetUserGroup(id) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	models, err := model.CacheGetGroupModels(ctx, userGroup) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    models, | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,9 +2,10 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -12,17 +13,17 @@ import ( | |||||||
|  |  | ||||||
| func GetOptions(c *gin.Context) { | func GetOptions(c *gin.Context) { | ||||||
| 	var options []*model.Option | 	var options []*model.Option | ||||||
| 	common.OptionMapRWMutex.Lock() | 	config.OptionMapRWMutex.Lock() | ||||||
| 	for k, v := range common.OptionMap { | 	for k, v := range config.OptionMap { | ||||||
| 		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { | 		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		options = append(options, &model.Option{ | 		options = append(options, &model.Option{ | ||||||
| 			Key:   k, | 			Key:   k, | ||||||
| 			Value: common.Interface2String(v), | 			Value: helper.Interface2String(v), | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 	common.OptionMapRWMutex.Unlock() | 	config.OptionMapRWMutex.Unlock() | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| @@ -43,7 +44,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	switch option.Key { | 	switch option.Key { | ||||||
| 	case "Theme": | 	case "Theme": | ||||||
| 		if !common.ValidThemes[option.Value] { | 		if !config.ValidThemes[option.Value] { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无效的主题", | 				"message": "无效的主题", | ||||||
| @@ -51,7 +52,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "GitHubOAuthEnabled": | 	case "GitHubOAuthEnabled": | ||||||
| 		if option.Value == "true" && common.GitHubClientId == "" { | 		if option.Value == "true" && config.GitHubClientId == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | ||||||
| @@ -59,7 +60,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "EmailDomainRestrictionEnabled": | 	case "EmailDomainRestrictionEnabled": | ||||||
| 		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { | 		if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | ||||||
| @@ -67,7 +68,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "WeChatAuthEnabled": | 	case "WeChatAuthEnabled": | ||||||
| 		if option.Value == "true" && common.WeChatServerAddress == "" { | 		if option.Value == "true" && config.WeChatServerAddress == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用微信登录,请先填入微信登录相关配置信息!", | 				"message": "无法启用微信登录,请先填入微信登录相关配置信息!", | ||||||
| @@ -75,7 +76,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "TurnstileCheckEnabled": | 	case "TurnstileCheckEnabled": | ||||||
| 		if option.Value == "true" && common.TurnstileSiteKey == "" { | 		if option.Value == "true" && config.TurnstileSiteKey == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", | 				"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", | ||||||
|   | |||||||
| @@ -2,9 +2,12 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -13,7 +16,7 @@ func GetAllRedemptions(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) | 	redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -105,12 +108,12 @@ func AddRedemption(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	var keys []string | 	var keys []string | ||||||
| 	for i := 0; i < redemption.Count; i++ { | 	for i := 0; i < redemption.Count; i++ { | ||||||
| 		key := common.GetUUID() | 		key := random.GetUUID() | ||||||
| 		cleanRedemption := model.Redemption{ | 		cleanRedemption := model.Redemption{ | ||||||
| 			UserId:      c.GetInt("id"), | 			UserId:      c.GetInt(ctxkey.Id), | ||||||
| 			Name:        redemption.Name, | 			Name:        redemption.Name, | ||||||
| 			Key:         key, | 			Key:         key, | ||||||
| 			CreatedTime: common.GetTimestamp(), | 			CreatedTime: helper.GetTimestamp(), | ||||||
| 			Quota:       redemption.Quota, | 			Quota:       redemption.Quota, | ||||||
| 		} | 		} | ||||||
| 		err = cleanRedemption.Insert() | 		err = cleanRedemption.Insert() | ||||||
|   | |||||||
| @@ -1,220 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 |  | ||||||
|  |  | ||||||
| type AIProxyLibraryRequest struct { |  | ||||||
| 	Model     string `json:"model"` |  | ||||||
| 	Query     string `json:"query"` |  | ||||||
| 	LibraryId string `json:"libraryId"` |  | ||||||
| 	Stream    bool   `json:"stream"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AIProxyLibraryError struct { |  | ||||||
| 	ErrCode int    `json:"errCode"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AIProxyLibraryDocument struct { |  | ||||||
| 	Title string `json:"title"` |  | ||||||
| 	URL   string `json:"url"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AIProxyLibraryResponse struct { |  | ||||||
| 	Success   bool                     `json:"success"` |  | ||||||
| 	Answer    string                   `json:"answer"` |  | ||||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` |  | ||||||
| 	AIProxyLibraryError |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AIProxyLibraryStreamResponse struct { |  | ||||||
| 	Content   string                   `json:"content"` |  | ||||||
| 	Finish    bool                     `json:"finish"` |  | ||||||
| 	Model     string                   `json:"model"` |  | ||||||
| 	Documents []AIProxyLibraryDocument `json:"documents"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { |  | ||||||
| 	query := "" |  | ||||||
| 	if len(request.Messages) != 0 { |  | ||||||
| 		query = request.Messages[len(request.Messages)-1].StringContent() |  | ||||||
| 	} |  | ||||||
| 	return &AIProxyLibraryRequest{ |  | ||||||
| 		Model:  request.Model, |  | ||||||
| 		Stream: request.Stream, |  | ||||||
| 		Query:  query, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { |  | ||||||
| 	if len(documents) == 0 { |  | ||||||
| 		return "" |  | ||||||
| 	} |  | ||||||
| 	content := "\n\n参考文档:\n" |  | ||||||
| 	for i, document := range documents { |  | ||||||
| 		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) |  | ||||||
| 	} |  | ||||||
| 	return content |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { |  | ||||||
| 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents) |  | ||||||
| 	choice := OpenAITextResponseChoice{ |  | ||||||
| 		Index: 0, |  | ||||||
| 		Message: Message{ |  | ||||||
| 			Role:    "assistant", |  | ||||||
| 			Content: content, |  | ||||||
| 		}, |  | ||||||
| 		FinishReason: "stop", |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Id:      common.GetUUID(), |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) |  | ||||||
| 	choice.FinishReason = &stopFinishReason |  | ||||||
| 	return &ChatCompletionsStreamResponse{ |  | ||||||
| 		Id:      common.GetUUID(), |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Model:   "", |  | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = response.Content |  | ||||||
| 	return &ChatCompletionsStreamResponse{ |  | ||||||
| 		Id:      common.GetUUID(), |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Model:   response.Model, |  | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var usage Usage |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { |  | ||||||
| 			return i + 1, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			if len(data) < 5 { // ignore blank line or wrong format |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if data[:5] != "data:" { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			data = data[5:] |  | ||||||
| 			dataChan <- data |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	var documents []AIProxyLibraryDocument |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			var AIProxyLibraryResponse AIProxyLibraryStreamResponse |  | ||||||
| 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			if len(AIProxyLibraryResponse.Documents) != 0 { |  | ||||||
| 				documents = AIProxyLibraryResponse.Documents |  | ||||||
| 			} |  | ||||||
| 			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			response := documentsAIProxyLibrary(documents) |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var AIProxyLibraryResponse AIProxyLibraryResponse |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if AIProxyLibraryResponse.ErrCode != 0 { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: AIProxyLibraryResponse.Message, |  | ||||||
| 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode), |  | ||||||
| 				Code:    AIProxyLibraryResponse.ErrCode, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } |  | ||||||
| @@ -1,322 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r |  | ||||||
|  |  | ||||||
| type AliMessage struct { |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliInput struct { |  | ||||||
| 	//Prompt   string       `json:"prompt"` |  | ||||||
| 	Messages []AliMessage `json:"messages"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliParameters struct { |  | ||||||
| 	TopP              float64 `json:"top_p,omitempty"` |  | ||||||
| 	TopK              int     `json:"top_k,omitempty"` |  | ||||||
| 	Seed              uint64  `json:"seed,omitempty"` |  | ||||||
| 	EnableSearch      bool    `json:"enable_search,omitempty"` |  | ||||||
| 	IncrementalOutput bool    `json:"incremental_output,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliChatRequest struct { |  | ||||||
| 	Model      string        `json:"model"` |  | ||||||
| 	Input      AliInput      `json:"input"` |  | ||||||
| 	Parameters AliParameters `json:"parameters,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliEmbeddingRequest struct { |  | ||||||
| 	Model string `json:"model"` |  | ||||||
| 	Input struct { |  | ||||||
| 		Texts []string `json:"texts"` |  | ||||||
| 	} `json:"input"` |  | ||||||
| 	Parameters *struct { |  | ||||||
| 		TextType string `json:"text_type,omitempty"` |  | ||||||
| 	} `json:"parameters,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliEmbedding struct { |  | ||||||
| 	Embedding []float64 `json:"embedding"` |  | ||||||
| 	TextIndex int       `json:"text_index"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliEmbeddingResponse struct { |  | ||||||
| 	Output struct { |  | ||||||
| 		Embeddings []AliEmbedding `json:"embeddings"` |  | ||||||
| 	} `json:"output"` |  | ||||||
| 	Usage AliUsage `json:"usage"` |  | ||||||
| 	AliError |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliError struct { |  | ||||||
| 	Code      string `json:"code"` |  | ||||||
| 	Message   string `json:"message"` |  | ||||||
| 	RequestId string `json:"request_id"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliUsage struct { |  | ||||||
| 	InputTokens  int `json:"input_tokens"` |  | ||||||
| 	OutputTokens int `json:"output_tokens"` |  | ||||||
| 	TotalTokens  int `json:"total_tokens"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliOutput struct { |  | ||||||
| 	Text         string `json:"text"` |  | ||||||
| 	FinishReason string `json:"finish_reason"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type AliChatResponse struct { |  | ||||||
| 	Output AliOutput `json:"output"` |  | ||||||
| 	Usage  AliUsage  `json:"usage"` |  | ||||||
| 	AliError |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const AliEnableSearchModelSuffix = "-internet" |  | ||||||
|  |  | ||||||
| func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { |  | ||||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) |  | ||||||
| 	for i := 0; i < len(request.Messages); i++ { |  | ||||||
| 		message := request.Messages[i] |  | ||||||
| 		messages = append(messages, AliMessage{ |  | ||||||
| 			Content: message.StringContent(), |  | ||||||
| 			Role:    strings.ToLower(message.Role), |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	enableSearch := false |  | ||||||
| 	aliModel := request.Model |  | ||||||
| 	if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) { |  | ||||||
| 		enableSearch = true |  | ||||||
| 		aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix) |  | ||||||
| 	} |  | ||||||
| 	return &AliChatRequest{ |  | ||||||
| 		Model: aliModel, |  | ||||||
| 		Input: AliInput{ |  | ||||||
| 			Messages: messages, |  | ||||||
| 		}, |  | ||||||
| 		Parameters: AliParameters{ |  | ||||||
| 			EnableSearch:      enableSearch, |  | ||||||
| 			IncrementalOutput: request.Stream, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { |  | ||||||
| 	return &AliEmbeddingRequest{ |  | ||||||
| 		Model: "text-embedding-v1", |  | ||||||
| 		Input: struct { |  | ||||||
| 			Texts []string `json:"texts"` |  | ||||||
| 		}{ |  | ||||||
| 			Texts: request.ParseInput(), |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var aliResponse AliEmbeddingResponse |  | ||||||
| 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if aliResponse.Code != "" { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: aliResponse.Message, |  | ||||||
| 				Type:    aliResponse.Code, |  | ||||||
| 				Param:   aliResponse.RequestId, |  | ||||||
| 				Code:    aliResponse.Code, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { |  | ||||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ |  | ||||||
| 		Object: "list", |  | ||||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), |  | ||||||
| 		Model:  "text-embedding-v1", |  | ||||||
| 		Usage:  Usage{TotalTokens: response.Usage.TotalTokens}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, item := range response.Output.Embeddings { |  | ||||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ |  | ||||||
| 			Object:    `embedding`, |  | ||||||
| 			Index:     item.TextIndex, |  | ||||||
| 			Embedding: item.Embedding, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	return &openAIEmbeddingResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { |  | ||||||
| 	choice := OpenAITextResponseChoice{ |  | ||||||
| 		Index: 0, |  | ||||||
| 		Message: Message{ |  | ||||||
| 			Role:    "assistant", |  | ||||||
| 			Content: response.Output.Text, |  | ||||||
| 		}, |  | ||||||
| 		FinishReason: response.Output.FinishReason, |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Id:      response.RequestId, |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, |  | ||||||
| 		Usage: Usage{ |  | ||||||
| 			PromptTokens:     response.Usage.InputTokens, |  | ||||||
| 			CompletionTokens: response.Usage.OutputTokens, |  | ||||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = aliResponse.Output.Text |  | ||||||
| 	if aliResponse.Output.FinishReason != "null" { |  | ||||||
| 		finishReason := aliResponse.Output.FinishReason |  | ||||||
| 		choice.FinishReason = &finishReason |  | ||||||
| 	} |  | ||||||
| 	response := ChatCompletionsStreamResponse{ |  | ||||||
| 		Id:      aliResponse.RequestId, |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Model:   "qwen", |  | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var usage Usage |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { |  | ||||||
| 			return i + 1, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			if len(data) < 5 { // ignore blank line or wrong format |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if data[:5] != "data:" { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			data = data[5:] |  | ||||||
| 			dataChan <- data |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	//lastResponseText := "" |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			var aliResponse AliChatResponse |  | ||||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			if aliResponse.Usage.OutputTokens != 0 { |  | ||||||
| 				usage.PromptTokens = aliResponse.Usage.InputTokens |  | ||||||
| 				usage.CompletionTokens = aliResponse.Usage.OutputTokens |  | ||||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens |  | ||||||
| 			} |  | ||||||
| 			response := streamResponseAli2OpenAI(&aliResponse) |  | ||||||
| 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) |  | ||||||
| 			//lastResponseText = aliResponse.Output.Text |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var aliResponse AliChatResponse |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &aliResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if aliResponse.Code != "" { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: aliResponse.Message, |  | ||||||
| 				Type:    aliResponse.Code, |  | ||||||
| 				Param:   aliResponse.RequestId, |  | ||||||
| 				Code:    aliResponse.Code, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseAli2OpenAI(&aliResponse) |  | ||||||
| 	fullTextResponse.Model = "qwen" |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } |  | ||||||
| @@ -1,262 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"bytes" |  | ||||||
| 	"context" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { |  | ||||||
| 	audioModel := "whisper-1" |  | ||||||
|  |  | ||||||
| 	tokenId := c.GetInt("token_id") |  | ||||||
| 	channelType := c.GetInt("channel") |  | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
| 	userId := c.GetInt("id") |  | ||||||
| 	group := c.GetString("group") |  | ||||||
| 	tokenName := c.GetString("token_name") |  | ||||||
|  |  | ||||||
| 	var ttsRequest TextToSpeechRequest |  | ||||||
| 	if relayMode == RelayModeAudioSpeech { |  | ||||||
| 		// Read JSON |  | ||||||
| 		err := common.UnmarshalBodyReusable(c, &ttsRequest) |  | ||||||
| 		// Check if JSON is valid |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "invalid_json", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 		audioModel = ttsRequest.Model |  | ||||||
| 		// Check if text is too long 4096 |  | ||||||
| 		if len(ttsRequest.Input) > 4096 { |  | ||||||
| 			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	modelRatio := common.GetModelRatio(audioModel) |  | ||||||
| 	groupRatio := common.GetGroupRatio(group) |  | ||||||
| 	ratio := modelRatio * groupRatio |  | ||||||
| 	var quota int |  | ||||||
| 	var preConsumedQuota int |  | ||||||
| 	switch relayMode { |  | ||||||
| 	case RelayModeAudioSpeech: |  | ||||||
| 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) |  | ||||||
| 		quota = preConsumedQuota |  | ||||||
| 	default: |  | ||||||
| 		preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) |  | ||||||
| 	} |  | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check if user quota is enough |  | ||||||
| 	if userQuota-preConsumedQuota < 0 { |  | ||||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) |  | ||||||
| 	} |  | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	if userQuota > 100*preConsumedQuota { |  | ||||||
| 		// in this case, we do not pre-consume quota |  | ||||||
| 		// because the user has enough quota |  | ||||||
| 		preConsumedQuota = 0 |  | ||||||
| 	} |  | ||||||
| 	if preConsumedQuota > 0 { |  | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// map model name |  | ||||||
| 	modelMapping := c.GetString("model_mapping") |  | ||||||
| 	if modelMapping != "" { |  | ||||||
| 		modelMap := make(map[string]string) |  | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		if modelMap[audioModel] != "" { |  | ||||||
| 			audioModel = modelMap[audioModel] |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] |  | ||||||
| 	requestURL := c.Request.URL.String() |  | ||||||
| 	if c.GetString("base_url") != "" { |  | ||||||
| 		baseURL = c.GetString("base_url") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) |  | ||||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { |  | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api |  | ||||||
| 		apiVersion := GetAPIVersion(c) |  | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	requestBody := &bytes.Buffer{} |  | ||||||
| 	_, err = io.Copy(requestBody, c.Request.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) |  | ||||||
| 	responseFormat := c.DefaultPostForm("response_format", "json") |  | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { |  | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api |  | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") |  | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") |  | ||||||
| 		req.Header.Set("api-key", apiKey) |  | ||||||
| 		req.ContentLength = c.Request.ContentLength |  | ||||||
| 	} else { |  | ||||||
| 		req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) |  | ||||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) |  | ||||||
|  |  | ||||||
| 	resp, err := httpClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err = req.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	err = c.Request.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if relayMode != RelayModeAudioSpeech { |  | ||||||
| 		responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		err = resp.Body.Close() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		var openAIErr TextResponse |  | ||||||
| 		if err = json.Unmarshal(responseBody, &openAIErr); err == nil { |  | ||||||
| 			if openAIErr.Error.Message != "" { |  | ||||||
| 				return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		var text string |  | ||||||
| 		switch responseFormat { |  | ||||||
| 		case "json": |  | ||||||
| 			text, err = getTextFromJSON(responseBody) |  | ||||||
| 		case "text": |  | ||||||
| 			text, err = getTextFromText(responseBody) |  | ||||||
| 		case "srt": |  | ||||||
| 			text, err = getTextFromSRT(responseBody) |  | ||||||
| 		case "verbose_json": |  | ||||||
| 			text, err = getTextFromVerboseJSON(responseBody) |  | ||||||
| 		case "vtt": |  | ||||||
| 			text, err = getTextFromVTT(responseBody) |  | ||||||
| 		default: |  | ||||||
| 			return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		quota = countTokenText(text, audioModel) |  | ||||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) |  | ||||||
| 	} |  | ||||||
| 	if resp.StatusCode != http.StatusOK { |  | ||||||
| 		if preConsumedQuota > 0 { |  | ||||||
| 			// we need to roll back the pre-consumed quota |  | ||||||
| 			defer func(ctx context.Context) { |  | ||||||
| 				go func() { |  | ||||||
| 					// negative means add quota back for token & user |  | ||||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) |  | ||||||
| 					if err != nil { |  | ||||||
| 						common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) |  | ||||||
| 					} |  | ||||||
| 				}() |  | ||||||
| 			}(c.Request.Context()) |  | ||||||
| 		} |  | ||||||
| 		return relayErrorHandler(resp) |  | ||||||
| 	} |  | ||||||
| 	quotaDelta := quota - preConsumedQuota |  | ||||||
| 	defer func(ctx context.Context) { |  | ||||||
| 		go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) |  | ||||||
| 	}(c.Request.Context()) |  | ||||||
|  |  | ||||||
| 	for k, v := range resp.Header { |  | ||||||
| 		c.Writer.Header().Set(k, v[0]) |  | ||||||
| 	} |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
|  |  | ||||||
| 	_, err = io.Copy(c.Writer, resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getTextFromVTT(body []byte) (string, error) { |  | ||||||
| 	return getTextFromSRT(body) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getTextFromVerboseJSON(body []byte) (string, error) { |  | ||||||
| 	var whisperResponse WhisperVerboseJSONResponse |  | ||||||
| 	if err := json.Unmarshal(body, &whisperResponse); err != nil { |  | ||||||
| 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) |  | ||||||
| 	} |  | ||||||
| 	return whisperResponse.Text, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getTextFromSRT(body []byte) (string, error) { |  | ||||||
| 	scanner := bufio.NewScanner(strings.NewReader(string(body))) |  | ||||||
| 	var builder strings.Builder |  | ||||||
| 	var textLine bool |  | ||||||
| 	for scanner.Scan() { |  | ||||||
| 		line := scanner.Text() |  | ||||||
| 		if textLine { |  | ||||||
| 			builder.WriteString(line) |  | ||||||
| 			textLine = false |  | ||||||
| 			continue |  | ||||||
| 		} else if strings.Contains(line, "-->") { |  | ||||||
| 			textLine = true |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if err := scanner.Err(); err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	return builder.String(), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getTextFromText(body []byte) (string, error) { |  | ||||||
| 	return strings.TrimSuffix(string(body), "\n"), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getTextFromJSON(body []byte) (string, error) { |  | ||||||
| 	var whisperResponse WhisperJSONResponse |  | ||||||
| 	if err := json.Unmarshal(body, &whisperResponse); err != nil { |  | ||||||
| 		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) |  | ||||||
| 	} |  | ||||||
| 	return whisperResponse.Text, nil |  | ||||||
| } |  | ||||||
| @@ -1,360 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"strings" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 |  | ||||||
|  |  | ||||||
| type BaiduTokenResponse struct { |  | ||||||
| 	ExpiresIn   int    `json:"expires_in"` |  | ||||||
| 	AccessToken string `json:"access_token"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduMessage struct { |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduChatRequest struct { |  | ||||||
| 	Messages []BaiduMessage `json:"messages"` |  | ||||||
| 	Stream   bool           `json:"stream"` |  | ||||||
| 	UserId   string         `json:"user_id,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduError struct { |  | ||||||
| 	ErrorCode int    `json:"error_code"` |  | ||||||
| 	ErrorMsg  string `json:"error_msg"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduChatResponse struct { |  | ||||||
| 	Id               string `json:"id"` |  | ||||||
| 	Object           string `json:"object"` |  | ||||||
| 	Created          int64  `json:"created"` |  | ||||||
| 	Result           string `json:"result"` |  | ||||||
| 	IsTruncated      bool   `json:"is_truncated"` |  | ||||||
| 	NeedClearHistory bool   `json:"need_clear_history"` |  | ||||||
| 	Usage            Usage  `json:"usage"` |  | ||||||
| 	BaiduError |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduChatStreamResponse struct { |  | ||||||
| 	BaiduChatResponse |  | ||||||
| 	SentenceId int  `json:"sentence_id"` |  | ||||||
| 	IsEnd      bool `json:"is_end"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduEmbeddingRequest struct { |  | ||||||
| 	Input []string `json:"input"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduEmbeddingData struct { |  | ||||||
| 	Object    string    `json:"object"` |  | ||||||
| 	Embedding []float64 `json:"embedding"` |  | ||||||
| 	Index     int       `json:"index"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduEmbeddingResponse struct { |  | ||||||
| 	Id      string               `json:"id"` |  | ||||||
| 	Object  string               `json:"object"` |  | ||||||
| 	Created int64                `json:"created"` |  | ||||||
| 	Data    []BaiduEmbeddingData `json:"data"` |  | ||||||
| 	Usage   Usage                `json:"usage"` |  | ||||||
| 	BaiduError |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type BaiduAccessToken struct { |  | ||||||
| 	AccessToken      string    `json:"access_token"` |  | ||||||
| 	Error            string    `json:"error,omitempty"` |  | ||||||
| 	ErrorDescription string    `json:"error_description,omitempty"` |  | ||||||
| 	ExpiresIn        int64     `json:"expires_in,omitempty"` |  | ||||||
| 	ExpiresAt        time.Time `json:"-"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var baiduTokenStore sync.Map |  | ||||||
|  |  | ||||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { |  | ||||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) |  | ||||||
| 	for _, message := range request.Messages { |  | ||||||
| 		if message.Role == "system" { |  | ||||||
| 			messages = append(messages, BaiduMessage{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, BaiduMessage{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { |  | ||||||
| 			messages = append(messages, BaiduMessage{ |  | ||||||
| 				Role:    message.Role, |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return &BaiduChatRequest{ |  | ||||||
| 		Messages: messages, |  | ||||||
| 		Stream:   request.Stream, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { |  | ||||||
| 	choice := OpenAITextResponseChoice{ |  | ||||||
| 		Index: 0, |  | ||||||
| 		Message: Message{ |  | ||||||
| 			Role:    "assistant", |  | ||||||
| 			Content: response.Result, |  | ||||||
| 		}, |  | ||||||
| 		FinishReason: "stop", |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Id:      response.Id, |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: response.Created, |  | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, |  | ||||||
| 		Usage:   response.Usage, |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = baiduResponse.Result |  | ||||||
| 	if baiduResponse.IsEnd { |  | ||||||
| 		choice.FinishReason = &stopFinishReason |  | ||||||
| 	} |  | ||||||
| 	response := ChatCompletionsStreamResponse{ |  | ||||||
| 		Id:      baiduResponse.Id, |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: baiduResponse.Created, |  | ||||||
| 		Model:   "ernie-bot", |  | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { |  | ||||||
| 	return &BaiduEmbeddingRequest{ |  | ||||||
| 		Input: request.ParseInput(), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { |  | ||||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ |  | ||||||
| 		Object: "list", |  | ||||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), |  | ||||||
| 		Model:  "baidu-embedding", |  | ||||||
| 		Usage:  response.Usage, |  | ||||||
| 	} |  | ||||||
| 	for _, item := range response.Data { |  | ||||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ |  | ||||||
| 			Object:    item.Object, |  | ||||||
| 			Index:     item.Index, |  | ||||||
| 			Embedding: item.Embedding, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	return &openAIEmbeddingResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var usage Usage |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { |  | ||||||
| 			return i + 1, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			if len(data) < 6 { // ignore blank line or wrong format |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			data = data[6:] |  | ||||||
| 			dataChan <- data |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			var baiduResponse BaiduChatStreamResponse |  | ||||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			if baiduResponse.Usage.TotalTokens != 0 { |  | ||||||
| 				usage.TotalTokens = baiduResponse.Usage.TotalTokens |  | ||||||
| 				usage.PromptTokens = baiduResponse.Usage.PromptTokens |  | ||||||
| 				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens |  | ||||||
| 			} |  | ||||||
| 			response := streamResponseBaidu2OpenAI(&baiduResponse) |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var baiduResponse BaiduChatResponse |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if baiduResponse.ErrorMsg != "" { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: baiduResponse.ErrorMsg, |  | ||||||
| 				Type:    "baidu_error", |  | ||||||
| 				Param:   "", |  | ||||||
| 				Code:    baiduResponse.ErrorCode, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse) |  | ||||||
| 	fullTextResponse.Model = "ernie-bot" |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var baiduResponse BaiduEmbeddingResponse |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if baiduResponse.ErrorMsg != "" { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: baiduResponse.ErrorMsg, |  | ||||||
| 				Type:    "baidu_error", |  | ||||||
| 				Param:   "", |  | ||||||
| 				Code:    baiduResponse.ErrorCode, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getBaiduAccessToken(apiKey string) (string, error) { |  | ||||||
| 	if val, ok := baiduTokenStore.Load(apiKey); ok { |  | ||||||
| 		var accessToken BaiduAccessToken |  | ||||||
| 		if accessToken, ok = val.(BaiduAccessToken); ok { |  | ||||||
| 			// soon this will expire |  | ||||||
| 			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { |  | ||||||
| 				go func() { |  | ||||||
| 					_, _ = getBaiduAccessTokenHelper(apiKey) |  | ||||||
| 				}() |  | ||||||
| 			} |  | ||||||
| 			return accessToken.AccessToken, nil |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	accessToken, err := getBaiduAccessTokenHelper(apiKey) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	if accessToken == nil { |  | ||||||
| 		return "", errors.New("getBaiduAccessToken return a nil token") |  | ||||||
| 	} |  | ||||||
| 	return (*accessToken).AccessToken, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { |  | ||||||
| 	parts := strings.Split(apiKey, "|") |  | ||||||
| 	if len(parts) != 2 { |  | ||||||
| 		return nil, errors.New("invalid baidu apikey") |  | ||||||
| 	} |  | ||||||
| 	req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", |  | ||||||
| 		parts[0], parts[1]), nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	req.Header.Add("Content-Type", "application/json") |  | ||||||
| 	req.Header.Add("Accept", "application/json") |  | ||||||
| 	res, err := impatientHTTPClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	defer res.Body.Close() |  | ||||||
|  |  | ||||||
| 	var accessToken BaiduAccessToken |  | ||||||
| 	err = json.NewDecoder(res.Body).Decode(&accessToken) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	if accessToken.Error != "" { |  | ||||||
| 		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) |  | ||||||
| 	} |  | ||||||
| 	if accessToken.AccessToken == "" { |  | ||||||
| 		return nil, errors.New("getBaiduAccessTokenHelper get empty access token") |  | ||||||
| 	} |  | ||||||
| 	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) |  | ||||||
| 	baiduTokenStore.Store(apiKey, accessToken) |  | ||||||
| 	return &accessToken, nil |  | ||||||
| } |  | ||||||
| @@ -1,223 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type ClaudeMetadata struct { |  | ||||||
| 	UserId string `json:"user_id"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ClaudeRequest struct { |  | ||||||
| 	Model             string   `json:"model"` |  | ||||||
| 	Prompt            string   `json:"prompt"` |  | ||||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` |  | ||||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` |  | ||||||
| 	Temperature       float64  `json:"temperature,omitempty"` |  | ||||||
| 	TopP              float64  `json:"top_p,omitempty"` |  | ||||||
| 	TopK              int      `json:"top_k,omitempty"` |  | ||||||
| 	//ClaudeMetadata    `json:"metadata,omitempty"` |  | ||||||
| 	Stream bool `json:"stream,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ClaudeError struct { |  | ||||||
| 	Type    string `json:"type"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ClaudeResponse struct { |  | ||||||
| 	Completion string      `json:"completion"` |  | ||||||
| 	StopReason string      `json:"stop_reason"` |  | ||||||
| 	Model      string      `json:"model"` |  | ||||||
| 	Error      ClaudeError `json:"error"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func stopReasonClaude2OpenAI(reason string) string { |  | ||||||
| 	switch reason { |  | ||||||
| 	case "stop_sequence": |  | ||||||
| 		return "stop" |  | ||||||
| 	case "max_tokens": |  | ||||||
| 		return "length" |  | ||||||
| 	default: |  | ||||||
| 		return reason |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { |  | ||||||
| 	claudeRequest := ClaudeRequest{ |  | ||||||
| 		Model:             textRequest.Model, |  | ||||||
| 		Prompt:            "", |  | ||||||
| 		MaxTokensToSample: textRequest.MaxTokens, |  | ||||||
| 		StopSequences:     nil, |  | ||||||
| 		Temperature:       textRequest.Temperature, |  | ||||||
| 		TopP:              textRequest.TopP, |  | ||||||
| 		Stream:            textRequest.Stream, |  | ||||||
| 	} |  | ||||||
| 	if claudeRequest.MaxTokensToSample == 0 { |  | ||||||
| 		claudeRequest.MaxTokensToSample = 1000000 |  | ||||||
| 	} |  | ||||||
| 	prompt := "" |  | ||||||
| 	for _, message := range textRequest.Messages { |  | ||||||
| 		if message.Role == "user" { |  | ||||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) |  | ||||||
| 		} else if message.Role == "assistant" { |  | ||||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) |  | ||||||
| 		} else if message.Role == "system" { |  | ||||||
| 			if prompt == "" { |  | ||||||
| 				prompt = message.StringContent() |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	prompt += "\n\nAssistant:" |  | ||||||
| 	claudeRequest.Prompt = prompt |  | ||||||
| 	return &claudeRequest |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = claudeResponse.Completion |  | ||||||
| 	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) |  | ||||||
| 	if finishReason != "null" { |  | ||||||
| 		choice.FinishReason = &finishReason |  | ||||||
| 	} |  | ||||||
| 	var response ChatCompletionsStreamResponse |  | ||||||
| 	response.Object = "chat.completion.chunk" |  | ||||||
| 	response.Model = claudeResponse.Model |  | ||||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { |  | ||||||
| 	choice := OpenAITextResponseChoice{ |  | ||||||
| 		Index: 0, |  | ||||||
| 		Message: Message{ |  | ||||||
| 			Role:    "assistant", |  | ||||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), |  | ||||||
| 			Name:    nil, |  | ||||||
| 		}, |  | ||||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { |  | ||||||
| 	responseText := "" |  | ||||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) |  | ||||||
| 	createdTime := common.GetTimestamp() |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { |  | ||||||
| 			return i + 4, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			if !strings.HasPrefix(data, "event: completion") { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			data = strings.TrimPrefix(data, "event: completion\r\ndata: ") |  | ||||||
| 			dataChan <- data |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			// some implementations may add \r at the end of data |  | ||||||
| 			data = strings.TrimSuffix(data, "\r") |  | ||||||
| 			var claudeResponse ClaudeResponse |  | ||||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			responseText += claudeResponse.Completion |  | ||||||
| 			response := streamResponseClaude2OpenAI(&claudeResponse) |  | ||||||
| 			response.Id = responseId |  | ||||||
| 			response.Created = createdTime |  | ||||||
| 			jsonStr, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" |  | ||||||
| 	} |  | ||||||
| 	return nil, responseText |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	var claudeResponse ClaudeResponse |  | ||||||
| 	err = json.Unmarshal(responseBody, &claudeResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if claudeResponse.Error.Type != "" { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: claudeResponse.Error.Message, |  | ||||||
| 				Type:    claudeResponse.Error.Type, |  | ||||||
| 				Param:   "", |  | ||||||
| 				Code:    claudeResponse.Error.Type, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) |  | ||||||
| 	fullTextResponse.Model = model |  | ||||||
| 	completionTokens := countTokenText(claudeResponse.Completion, model) |  | ||||||
| 	usage := Usage{ |  | ||||||
| 		PromptTokens:     promptTokens, |  | ||||||
| 		CompletionTokens: completionTokens, |  | ||||||
| 		TotalTokens:      promptTokens + completionTokens, |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse.Usage = usage |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
| @@ -1,337 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"fmt" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/common/image" |  | ||||||
| 	"strings" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	GeminiVisionMaxImageNum = 16 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type GeminiChatRequest struct { |  | ||||||
| 	Contents         []GeminiChatContent        `json:"contents"` |  | ||||||
| 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"` |  | ||||||
| 	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` |  | ||||||
| 	Tools            []GeminiChatTools          `json:"tools,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiInlineData struct { |  | ||||||
| 	MimeType string `json:"mimeType"` |  | ||||||
| 	Data     string `json:"data"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiPart struct { |  | ||||||
| 	Text       string            `json:"text,omitempty"` |  | ||||||
| 	InlineData *GeminiInlineData `json:"inlineData,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatContent struct { |  | ||||||
| 	Role  string       `json:"role,omitempty"` |  | ||||||
| 	Parts []GeminiPart `json:"parts"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatSafetySettings struct { |  | ||||||
| 	Category  string `json:"category"` |  | ||||||
| 	Threshold string `json:"threshold"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatTools struct { |  | ||||||
| 	FunctionDeclarations any `json:"functionDeclarations,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatGenerationConfig struct { |  | ||||||
| 	Temperature     float64  `json:"temperature,omitempty"` |  | ||||||
| 	TopP            float64  `json:"topP,omitempty"` |  | ||||||
| 	TopK            float64  `json:"topK,omitempty"` |  | ||||||
| 	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` |  | ||||||
| 	CandidateCount  int      `json:"candidateCount,omitempty"` |  | ||||||
| 	StopSequences   []string `json:"stopSequences,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Setting safety to the lowest possible values since Gemini is already powerless enough |  | ||||||
| func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { |  | ||||||
| 	geminiRequest := GeminiChatRequest{ |  | ||||||
| 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), |  | ||||||
| 		SafetySettings: []GeminiChatSafetySettings{ |  | ||||||
| 			{ |  | ||||||
| 				Category:  "HARM_CATEGORY_HARASSMENT", |  | ||||||
| 				Threshold: common.GeminiSafetySetting, |  | ||||||
| 			}, |  | ||||||
| 			{ |  | ||||||
| 				Category:  "HARM_CATEGORY_HATE_SPEECH", |  | ||||||
| 				Threshold: common.GeminiSafetySetting, |  | ||||||
| 			}, |  | ||||||
| 			{ |  | ||||||
| 				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT", |  | ||||||
| 				Threshold: common.GeminiSafetySetting, |  | ||||||
| 			}, |  | ||||||
| 			{ |  | ||||||
| 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", |  | ||||||
| 				Threshold: common.GeminiSafetySetting, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		GenerationConfig: GeminiChatGenerationConfig{ |  | ||||||
| 			Temperature:     textRequest.Temperature, |  | ||||||
| 			TopP:            textRequest.TopP, |  | ||||||
| 			MaxOutputTokens: textRequest.MaxTokens, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	if textRequest.Functions != nil { |  | ||||||
| 		geminiRequest.Tools = []GeminiChatTools{ |  | ||||||
| 			{ |  | ||||||
| 				FunctionDeclarations: textRequest.Functions, |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	shouldAddDummyModelMessage := false |  | ||||||
| 	for _, message := range textRequest.Messages { |  | ||||||
| 		content := GeminiChatContent{ |  | ||||||
| 			Role: message.Role, |  | ||||||
| 			Parts: []GeminiPart{ |  | ||||||
| 				{ |  | ||||||
| 					Text: message.StringContent(), |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 		openaiContent := message.ParseContent() |  | ||||||
| 		var parts []GeminiPart |  | ||||||
| 		imageNum := 0 |  | ||||||
| 		for _, part := range openaiContent { |  | ||||||
| 			if part.Type == ContentTypeText { |  | ||||||
| 				parts = append(parts, GeminiPart{ |  | ||||||
| 					Text: part.Text, |  | ||||||
| 				}) |  | ||||||
| 			} else if part.Type == ContentTypeImageURL { |  | ||||||
| 				imageNum += 1 |  | ||||||
| 				if imageNum > GeminiVisionMaxImageNum { |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
| 				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) |  | ||||||
| 				parts = append(parts, GeminiPart{ |  | ||||||
| 					InlineData: &GeminiInlineData{ |  | ||||||
| 						MimeType: mimeType, |  | ||||||
| 						Data:     data, |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		content.Parts = parts |  | ||||||
|  |  | ||||||
| 		// there's no assistant role in gemini and API shall vomit if Role is not user or model |  | ||||||
| 		if content.Role == "assistant" { |  | ||||||
| 			content.Role = "model" |  | ||||||
| 		} |  | ||||||
| 		// Converting system prompt to prompt from user for the same reason |  | ||||||
| 		if content.Role == "system" { |  | ||||||
| 			content.Role = "user" |  | ||||||
| 			shouldAddDummyModelMessage = true |  | ||||||
| 		} |  | ||||||
| 		geminiRequest.Contents = append(geminiRequest.Contents, content) |  | ||||||
|  |  | ||||||
| 		// If a system message is the last message, we need to add a dummy model message to make gemini happy |  | ||||||
| 		if shouldAddDummyModelMessage { |  | ||||||
| 			geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ |  | ||||||
| 				Role: "model", |  | ||||||
| 				Parts: []GeminiPart{ |  | ||||||
| 					{ |  | ||||||
| 						Text: "Okay", |  | ||||||
| 					}, |  | ||||||
| 				}, |  | ||||||
| 			}) |  | ||||||
| 			shouldAddDummyModelMessage = false |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &geminiRequest |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatResponse struct { |  | ||||||
| 	Candidates     []GeminiChatCandidate    `json:"candidates"` |  | ||||||
| 	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (g *GeminiChatResponse) GetResponseText() string { |  | ||||||
| 	if g == nil { |  | ||||||
| 		return "" |  | ||||||
| 	} |  | ||||||
| 	if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { |  | ||||||
| 		return g.Candidates[0].Content.Parts[0].Text |  | ||||||
| 	} |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatCandidate struct { |  | ||||||
| 	Content       GeminiChatContent        `json:"content"` |  | ||||||
| 	FinishReason  string                   `json:"finishReason"` |  | ||||||
| 	Index         int64                    `json:"index"` |  | ||||||
| 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatSafetyRating struct { |  | ||||||
| 	Category    string `json:"category"` |  | ||||||
| 	Probability string `json:"probability"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type GeminiChatPromptFeedback struct { |  | ||||||
| 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), |  | ||||||
| 	} |  | ||||||
| 	for i, candidate := range response.Candidates { |  | ||||||
| 		choice := OpenAITextResponseChoice{ |  | ||||||
| 			Index: i, |  | ||||||
| 			Message: Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "", |  | ||||||
| 			}, |  | ||||||
| 			FinishReason: stopFinishReason, |  | ||||||
| 		} |  | ||||||
| 		if len(candidate.Content.Parts) > 0 { |  | ||||||
| 			choice.Message.Content = candidate.Content.Parts[0].Text |  | ||||||
| 		} |  | ||||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = geminiResponse.GetResponseText() |  | ||||||
| 	choice.FinishReason = &stopFinishReason |  | ||||||
| 	var response ChatCompletionsStreamResponse |  | ||||||
| 	response.Object = "chat.completion.chunk" |  | ||||||
| 	response.Model = "gemini" |  | ||||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { |  | ||||||
| 	responseText := "" |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { |  | ||||||
| 			return i + 1, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			data = strings.TrimSpace(data) |  | ||||||
| 			if !strings.HasPrefix(data, "\"text\": \"") { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			data = strings.TrimPrefix(data, "\"text\": \"") |  | ||||||
| 			data = strings.TrimSuffix(data, "\"") |  | ||||||
| 			dataChan <- data |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			// this is used to prevent annoying \ related format bug |  | ||||||
| 			data = fmt.Sprintf("{\"content\": \"%s\"}", data) |  | ||||||
| 			type dummyStruct struct { |  | ||||||
| 				Content string `json:"content"` |  | ||||||
| 			} |  | ||||||
| 			var dummy dummyStruct |  | ||||||
| 			err := json.Unmarshal([]byte(data), &dummy) |  | ||||||
| 			responseText += dummy.Content |  | ||||||
| 			var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 			choice.Delta.Content = dummy.Content |  | ||||||
| 			response := ChatCompletionsStreamResponse{ |  | ||||||
| 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), |  | ||||||
| 				Object:  "chat.completion.chunk", |  | ||||||
| 				Created: common.GetTimestamp(), |  | ||||||
| 				Model:   "gemini-pro", |  | ||||||
| 				Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 			} |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" |  | ||||||
| 	} |  | ||||||
| 	return nil, responseText |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	var geminiResponse GeminiChatResponse |  | ||||||
| 	err = json.Unmarshal(responseBody, &geminiResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if len(geminiResponse.Candidates) == 0 { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: "No candidates returned", |  | ||||||
| 				Type:    "server_error", |  | ||||||
| 				Param:   "", |  | ||||||
| 				Code:    500, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) |  | ||||||
| 	fullTextResponse.Model = model |  | ||||||
| 	completionTokens := countTokenText(geminiResponse.GetResponseText(), model) |  | ||||||
| 	usage := Usage{ |  | ||||||
| 		PromptTokens:     promptTokens, |  | ||||||
| 		CompletionTokens: completionTokens, |  | ||||||
| 		TotalTokens:      promptTokens + completionTokens, |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse.Usage = usage |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
| @@ -1,222 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"context" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strings" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func isWithinRange(element string, value int) bool { |  | ||||||
| 	if _, ok := common.DalleGenerationImageAmounts[element]; !ok { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	min := common.DalleGenerationImageAmounts[element][0] |  | ||||||
| 	max := common.DalleGenerationImageAmounts[element][1] |  | ||||||
|  |  | ||||||
| 	return value >= min && value <= max |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { |  | ||||||
| 	imageModel := "dall-e-2" |  | ||||||
| 	imageSize := "1024x1024" |  | ||||||
|  |  | ||||||
| 	tokenId := c.GetInt("token_id") |  | ||||||
| 	channelType := c.GetInt("channel") |  | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
| 	userId := c.GetInt("id") |  | ||||||
| 	group := c.GetString("group") |  | ||||||
|  |  | ||||||
| 	var imageRequest ImageRequest |  | ||||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if imageRequest.N == 0 { |  | ||||||
| 		imageRequest.N = 1 |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Size validation |  | ||||||
| 	if imageRequest.Size != "" { |  | ||||||
| 		imageSize = imageRequest.Size |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Model validation |  | ||||||
| 	if imageRequest.Model != "" { |  | ||||||
| 		imageModel = imageRequest.Model |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] |  | ||||||
|  |  | ||||||
| 	// Check if model is supported |  | ||||||
| 	if hasValidSize { |  | ||||||
| 		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { |  | ||||||
| 			if imageSize == "1024x1024" { |  | ||||||
| 				imageCostRatio *= 2 |  | ||||||
| 			} else { |  | ||||||
| 				imageCostRatio *= 1.5 |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Prompt validation |  | ||||||
| 	if imageRequest.Prompt == "" { |  | ||||||
| 		return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check prompt length |  | ||||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { |  | ||||||
| 		return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Number of generated images validation |  | ||||||
| 	if isWithinRange(imageModel, imageRequest.N) == false { |  | ||||||
| 		// channel not azure |  | ||||||
| 		if channelType != common.ChannelTypeAzure { |  | ||||||
| 			return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// map model name |  | ||||||
| 	modelMapping := c.GetString("model_mapping") |  | ||||||
| 	isModelMapped := false |  | ||||||
| 	if modelMapping != "" { |  | ||||||
| 		modelMap := make(map[string]string) |  | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		if modelMap[imageModel] != "" { |  | ||||||
| 			imageModel = modelMap[imageModel] |  | ||||||
| 			isModelMapped = true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] |  | ||||||
| 	requestURL := c.Request.URL.String() |  | ||||||
| 	if c.GetString("base_url") != "" { |  | ||||||
| 		baseURL = c.GetString("base_url") |  | ||||||
| 	} |  | ||||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) |  | ||||||
| 	if channelType == common.ChannelTypeAzure { |  | ||||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api |  | ||||||
| 		apiVersion := GetAPIVersion(c) |  | ||||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview |  | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var requestBody io.Reader |  | ||||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body |  | ||||||
| 		jsonStr, err := json.Marshal(imageRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	} else { |  | ||||||
| 		requestBody = c.Request.Body |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	modelRatio := common.GetModelRatio(imageModel) |  | ||||||
| 	groupRatio := common.GetGroupRatio(group) |  | ||||||
| 	ratio := modelRatio * groupRatio |  | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) |  | ||||||
|  |  | ||||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N |  | ||||||
|  |  | ||||||
| 	if userQuota-quota < 0 { |  | ||||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	token := c.Request.Header.Get("Authorization") |  | ||||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication |  | ||||||
| 		token = strings.TrimPrefix(token, "Bearer ") |  | ||||||
| 		req.Header.Set("api-key", token) |  | ||||||
| 	} else { |  | ||||||
| 		req.Header.Set("Authorization", token) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) |  | ||||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) |  | ||||||
|  |  | ||||||
| 	resp, err := httpClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	err = req.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	err = c.Request.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	var textResponse ImageResponse |  | ||||||
|  |  | ||||||
| 	defer func(ctx context.Context) { |  | ||||||
| 		if resp.StatusCode != http.StatusOK { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("error consuming token remain quota: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		err = model.CacheUpdateUserQuota(userId) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("error update user quota cache: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		if quota != 0 { |  | ||||||
| 			tokenName := c.GetString("token_name") |  | ||||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) |  | ||||||
| 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) |  | ||||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) |  | ||||||
| 			channelId := c.GetInt("channel_id") |  | ||||||
| 			model.UpdateChannelUsedQuota(channelId, quota) |  | ||||||
| 		} |  | ||||||
| 	}(c.Request.Context()) |  | ||||||
|  |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
|  |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &textResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) |  | ||||||
|  |  | ||||||
| 	for k, v := range resp.Header { |  | ||||||
| 		c.Writer.Header().Set(k, v[0]) |  | ||||||
| 	} |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
|  |  | ||||||
| 	_, err = io.Copy(c.Writer, resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| @@ -1,143 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"bytes" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { |  | ||||||
| 	responseText := "" |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { |  | ||||||
| 			return i + 1, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			if len(data) < 6 { // ignore blank line or wrong format |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if data[:6] != "data: " && data[:6] != "[DONE]" { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			dataChan <- data |  | ||||||
| 			data = data[6:] |  | ||||||
| 			if !strings.HasPrefix(data, "[DONE]") { |  | ||||||
| 				switch relayMode { |  | ||||||
| 				case RelayModeChatCompletions: |  | ||||||
| 					var streamResponse ChatCompletionsStreamResponse |  | ||||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) |  | ||||||
| 					if err != nil { |  | ||||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 						continue // just ignore the error |  | ||||||
| 					} |  | ||||||
| 					for _, choice := range streamResponse.Choices { |  | ||||||
| 						responseText += choice.Delta.Content |  | ||||||
| 					} |  | ||||||
| 				case RelayModeCompletions: |  | ||||||
| 					var streamResponse CompletionsStreamResponse |  | ||||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) |  | ||||||
| 					if err != nil { |  | ||||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 						continue |  | ||||||
| 					} |  | ||||||
| 					for _, choice := range streamResponse.Choices { |  | ||||||
| 						responseText += choice.Text |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			if strings.HasPrefix(data, "data: [DONE]") { |  | ||||||
| 				data = data[:12] |  | ||||||
| 			} |  | ||||||
| 			// some implementations may add \r at the end of data |  | ||||||
| 			data = strings.TrimSuffix(data, "\r") |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: data}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" |  | ||||||
| 	} |  | ||||||
| 	return nil, responseText |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var textResponse TextResponse |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &textResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if textResponse.Error.Type != "" { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: textResponse.Error, |  | ||||||
| 			StatusCode:  resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	// Reset response body |  | ||||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) |  | ||||||
|  |  | ||||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. |  | ||||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. |  | ||||||
| 	// So the httpClient will be confused by the response. |  | ||||||
| 	// For example, Postman will report error, and we cannot check the response at all. |  | ||||||
| 	for k, v := range resp.Header { |  | ||||||
| 		c.Writer.Header().Set(k, v[0]) |  | ||||||
| 	} |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = io.Copy(c.Writer, resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if textResponse.Usage.TotalTokens == 0 { |  | ||||||
| 		completionTokens := 0 |  | ||||||
| 		for _, choice := range textResponse.Choices { |  | ||||||
| 			completionTokens += countTokenText(choice.Message.StringContent(), model) |  | ||||||
| 		} |  | ||||||
| 		textResponse.Usage = Usage{ |  | ||||||
| 			PromptTokens:     promptTokens, |  | ||||||
| 			CompletionTokens: completionTokens, |  | ||||||
| 			TotalTokens:      promptTokens + completionTokens, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return nil, &textResponse.Usage |  | ||||||
| } |  | ||||||
| @@ -1,206 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body |  | ||||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body |  | ||||||
|  |  | ||||||
| type PaLMChatMessage struct { |  | ||||||
| 	Author  string `json:"author"` |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type PaLMFilter struct { |  | ||||||
| 	Reason  string `json:"reason"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type PaLMPrompt struct { |  | ||||||
| 	Messages []PaLMChatMessage `json:"messages"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type PaLMChatRequest struct { |  | ||||||
| 	Prompt         PaLMPrompt `json:"prompt"` |  | ||||||
| 	Temperature    float64    `json:"temperature,omitempty"` |  | ||||||
| 	CandidateCount int        `json:"candidateCount,omitempty"` |  | ||||||
| 	TopP           float64    `json:"topP,omitempty"` |  | ||||||
| 	TopK           int        `json:"topK,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type PaLMError struct { |  | ||||||
| 	Code    int    `json:"code"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| 	Status  string `json:"status"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type PaLMChatResponse struct { |  | ||||||
| 	Candidates []PaLMChatMessage `json:"candidates"` |  | ||||||
| 	Messages   []Message         `json:"messages"` |  | ||||||
| 	Filters    []PaLMFilter      `json:"filters"` |  | ||||||
| 	Error      PaLMError         `json:"error"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { |  | ||||||
| 	palmRequest := PaLMChatRequest{ |  | ||||||
| 		Prompt: PaLMPrompt{ |  | ||||||
| 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), |  | ||||||
| 		}, |  | ||||||
| 		Temperature:    textRequest.Temperature, |  | ||||||
| 		CandidateCount: textRequest.N, |  | ||||||
| 		TopP:           textRequest.TopP, |  | ||||||
| 		TopK:           textRequest.MaxTokens, |  | ||||||
| 	} |  | ||||||
| 	for _, message := range textRequest.Messages { |  | ||||||
| 		palmMessage := PaLMChatMessage{ |  | ||||||
| 			Content: message.StringContent(), |  | ||||||
| 		} |  | ||||||
| 		if message.Role == "user" { |  | ||||||
| 			palmMessage.Author = "0" |  | ||||||
| 		} else { |  | ||||||
| 			palmMessage.Author = "1" |  | ||||||
| 		} |  | ||||||
| 		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) |  | ||||||
| 	} |  | ||||||
| 	return &palmRequest |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), |  | ||||||
| 	} |  | ||||||
| 	for i, candidate := range response.Candidates { |  | ||||||
| 		choice := OpenAITextResponseChoice{ |  | ||||||
| 			Index: i, |  | ||||||
| 			Message: Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: candidate.Content, |  | ||||||
| 			}, |  | ||||||
| 			FinishReason: "stop", |  | ||||||
| 		} |  | ||||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	if len(palmResponse.Candidates) > 0 { |  | ||||||
| 		choice.Delta.Content = palmResponse.Candidates[0].Content |  | ||||||
| 	} |  | ||||||
| 	choice.FinishReason = &stopFinishReason |  | ||||||
| 	var response ChatCompletionsStreamResponse |  | ||||||
| 	response.Object = "chat.completion.chunk" |  | ||||||
| 	response.Model = "palm2" |  | ||||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { |  | ||||||
| 	responseText := "" |  | ||||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) |  | ||||||
| 	createdTime := common.GetTimestamp() |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("error reading stream response: " + err.Error()) |  | ||||||
| 			stopChan <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		err = resp.Body.Close() |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("error closing stream response: " + err.Error()) |  | ||||||
| 			stopChan <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		var palmResponse PaLMChatResponse |  | ||||||
| 		err = json.Unmarshal(responseBody, &palmResponse) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 			stopChan <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) |  | ||||||
| 		fullTextResponse.Id = responseId |  | ||||||
| 		fullTextResponse.Created = createdTime |  | ||||||
| 		if len(palmResponse.Candidates) > 0 { |  | ||||||
| 			responseText = palmResponse.Candidates[0].Content |  | ||||||
| 		} |  | ||||||
| 		jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 			stopChan <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		dataChan <- string(jsonResponse) |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + data}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" |  | ||||||
| 	} |  | ||||||
| 	return nil, responseText |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	var palmResponse PaLMChatResponse |  | ||||||
| 	err = json.Unmarshal(responseBody, &palmResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: palmResponse.Error.Message, |  | ||||||
| 				Type:    palmResponse.Error.Status, |  | ||||||
| 				Param:   "", |  | ||||||
| 				Code:    palmResponse.Error.Code, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) |  | ||||||
| 	fullTextResponse.Model = model |  | ||||||
| 	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) |  | ||||||
| 	usage := Usage{ |  | ||||||
| 		PromptTokens:     promptTokens, |  | ||||||
| 		CompletionTokens: completionTokens, |  | ||||||
| 		TotalTokens:      promptTokens + completionTokens, |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse.Usage = usage |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
| @@ -1,288 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"crypto/hmac" |  | ||||||
| 	"crypto/sha1" |  | ||||||
| 	"encoding/base64" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"sort" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://cloud.tencent.com/document/product/1729/97732 |  | ||||||
|  |  | ||||||
| type TencentMessage struct { |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TencentChatRequest struct { |  | ||||||
| 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID |  | ||||||
| 	SecretId string `json:"secret_id"` // 官网 SecretId |  | ||||||
| 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 |  | ||||||
| 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 |  | ||||||
| 	Timestamp int64 `json:"timestamp"` |  | ||||||
| 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, |  | ||||||
| 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 |  | ||||||
| 	Expired int64  `json:"expired"` |  | ||||||
| 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 |  | ||||||
| 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 |  | ||||||
| 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 |  | ||||||
| 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p |  | ||||||
| 	Temperature float64 `json:"temperature"` |  | ||||||
| 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 |  | ||||||
| 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 |  | ||||||
| 	// 建议该参数和 temperature 只设置1个,不要同时更改 |  | ||||||
| 	TopP float64 `json:"top_p"` |  | ||||||
| 	// Stream 0:同步,1:流式 (默认,协议:SSE) |  | ||||||
| 	// 同步请求超时:60s,如果内容较长建议使用流式 |  | ||||||
| 	Stream int `json:"stream"` |  | ||||||
| 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 |  | ||||||
| 	// 输入 content 总数最大支持 3000 token。 |  | ||||||
| 	Messages []TencentMessage `json:"messages"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TencentError struct { |  | ||||||
| 	Code    int    `json:"code"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TencentUsage struct { |  | ||||||
| 	InputTokens  int `json:"input_tokens"` |  | ||||||
| 	OutputTokens int `json:"output_tokens"` |  | ||||||
| 	TotalTokens  int `json:"total_tokens"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TencentResponseChoices struct { |  | ||||||
| 	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 |  | ||||||
| 	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 |  | ||||||
| 	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TencentChatResponse struct { |  | ||||||
| 	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 |  | ||||||
| 	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串 |  | ||||||
| 	Id      string                   `json:"id,omitempty"`      // 会话 id |  | ||||||
| 	Usage   Usage                    `json:"usage,omitempty"`   // token 数量 |  | ||||||
| 	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 |  | ||||||
| 	Note    string                   `json:"note,omitempty"`    // 注释 |  | ||||||
| 	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { |  | ||||||
| 	messages := make([]TencentMessage, 0, len(request.Messages)) |  | ||||||
| 	for i := 0; i < len(request.Messages); i++ { |  | ||||||
| 		message := request.Messages[i] |  | ||||||
| 		if message.Role == "system" { |  | ||||||
| 			messages = append(messages, TencentMessage{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, TencentMessage{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		messages = append(messages, TencentMessage{ |  | ||||||
| 			Content: message.StringContent(), |  | ||||||
| 			Role:    message.Role, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	stream := 0 |  | ||||||
| 	if request.Stream { |  | ||||||
| 		stream = 1 |  | ||||||
| 	} |  | ||||||
| 	return &TencentChatRequest{ |  | ||||||
| 		Timestamp:   common.GetTimestamp(), |  | ||||||
| 		Expired:     common.GetTimestamp() + 24*60*60, |  | ||||||
| 		QueryID:     common.GetUUID(), |  | ||||||
| 		Temperature: request.Temperature, |  | ||||||
| 		TopP:        request.TopP, |  | ||||||
| 		Stream:      stream, |  | ||||||
| 		Messages:    messages, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Usage:   response.Usage, |  | ||||||
| 	} |  | ||||||
| 	if len(response.Choices) > 0 { |  | ||||||
| 		choice := OpenAITextResponseChoice{ |  | ||||||
| 			Index: 0, |  | ||||||
| 			Message: Message{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: response.Choices[0].Messages.Content, |  | ||||||
| 			}, |  | ||||||
| 			FinishReason: response.Choices[0].FinishReason, |  | ||||||
| 		} |  | ||||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	response := ChatCompletionsStreamResponse{ |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Model:   "tencent-hunyuan", |  | ||||||
| 	} |  | ||||||
| 	if len(TencentResponse.Choices) > 0 { |  | ||||||
| 		var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content |  | ||||||
| 		if TencentResponse.Choices[0].FinishReason == "stop" { |  | ||||||
| 			choice.FinishReason = &stopFinishReason |  | ||||||
| 		} |  | ||||||
| 		response.Choices = append(response.Choices, choice) |  | ||||||
| 	} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { |  | ||||||
| 	var responseText string |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { |  | ||||||
| 			return i + 1, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			if len(data) < 5 { // ignore blank line or wrong format |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if data[:5] != "data:" { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			data = data[5:] |  | ||||||
| 			dataChan <- data |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			var TencentResponse TencentChatResponse |  | ||||||
| 			err := json.Unmarshal([]byte(data), &TencentResponse) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			response := streamResponseTencent2OpenAI(&TencentResponse) |  | ||||||
| 			if len(response.Choices) != 0 { |  | ||||||
| 				responseText += response.Choices[0].Delta.Content |  | ||||||
| 			} |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" |  | ||||||
| 	} |  | ||||||
| 	return nil, responseText |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var TencentResponse TencentChatResponse |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &TencentResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if TencentResponse.Error.Code != 0 { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: TencentResponse.Error.Message, |  | ||||||
| 				Code:    TencentResponse.Error.Code, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseTencent2OpenAI(&TencentResponse) |  | ||||||
| 	fullTextResponse.Model = "hunyuan" |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { |  | ||||||
| 	parts := strings.Split(config, "|") |  | ||||||
| 	if len(parts) != 3 { |  | ||||||
| 		err = errors.New("invalid tencent config") |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	appId, err = strconv.ParseInt(parts[0], 10, 64) |  | ||||||
| 	secretId = parts[1] |  | ||||||
| 	secretKey = parts[2] |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getTencentSign(req TencentChatRequest, secretKey string) string { |  | ||||||
| 	params := make([]string, 0) |  | ||||||
| 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) |  | ||||||
| 	params = append(params, "secret_id="+req.SecretId) |  | ||||||
| 	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) |  | ||||||
| 	params = append(params, "query_id="+req.QueryID) |  | ||||||
| 	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) |  | ||||||
| 	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) |  | ||||||
| 	params = append(params, "stream="+strconv.Itoa(req.Stream)) |  | ||||||
| 	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) |  | ||||||
|  |  | ||||||
| 	var messageStr string |  | ||||||
| 	for _, msg := range req.Messages { |  | ||||||
| 		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) |  | ||||||
| 	} |  | ||||||
| 	messageStr = strings.TrimSuffix(messageStr, ",") |  | ||||||
| 	params = append(params, "messages=["+messageStr+"]") |  | ||||||
|  |  | ||||||
| 	sort.Sort(sort.StringSlice(params)) |  | ||||||
| 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") |  | ||||||
| 	mac := hmac.New(sha1.New, []byte(secretKey)) |  | ||||||
| 	signURL := url |  | ||||||
| 	mac.Write([]byte(signURL)) |  | ||||||
| 	sign := mac.Sum([]byte(nil)) |  | ||||||
| 	return base64.StdEncoding.EncodeToString(sign) |  | ||||||
| } |  | ||||||
| @@ -1,689 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"context" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"io" |  | ||||||
| 	"math" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	APITypeOpenAI = iota |  | ||||||
| 	APITypeClaude |  | ||||||
| 	APITypePaLM |  | ||||||
| 	APITypeBaidu |  | ||||||
| 	APITypeZhipu |  | ||||||
| 	APITypeAli |  | ||||||
| 	APITypeXunfei |  | ||||||
| 	APITypeAIProxyLibrary |  | ||||||
| 	APITypeTencent |  | ||||||
| 	APITypeGemini |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var httpClient *http.Client |  | ||||||
| var impatientHTTPClient *http.Client |  | ||||||
|  |  | ||||||
| func init() { |  | ||||||
| 	if common.RelayTimeout == 0 { |  | ||||||
| 		httpClient = &http.Client{} |  | ||||||
| 	} else { |  | ||||||
| 		httpClient = &http.Client{ |  | ||||||
| 			Timeout: time.Duration(common.RelayTimeout) * time.Second, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	impatientHTTPClient = &http.Client{ |  | ||||||
| 		Timeout: 5 * time.Second, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { |  | ||||||
| 	channelType := c.GetInt("channel") |  | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
| 	tokenId := c.GetInt("token_id") |  | ||||||
| 	userId := c.GetInt("id") |  | ||||||
| 	group := c.GetString("group") |  | ||||||
| 	var textRequest GeneralOpenAIRequest |  | ||||||
| 	err := common.UnmarshalBodyReusable(c, &textRequest) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
| 	if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { |  | ||||||
| 		return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { |  | ||||||
| 		textRequest.Model = "text-moderation-latest" |  | ||||||
| 	} |  | ||||||
| 	if relayMode == RelayModeEmbeddings && textRequest.Model == "" { |  | ||||||
| 		textRequest.Model = c.Param("model") |  | ||||||
| 	} |  | ||||||
| 	// request validation |  | ||||||
| 	if textRequest.Model == "" { |  | ||||||
| 		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) |  | ||||||
| 	} |  | ||||||
| 	switch relayMode { |  | ||||||
| 	case RelayModeCompletions: |  | ||||||
| 		if textRequest.Prompt == "" { |  | ||||||
| 			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	case RelayModeChatCompletions: |  | ||||||
| 		if textRequest.Messages == nil || len(textRequest.Messages) == 0 { |  | ||||||
| 			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	case RelayModeEmbeddings: |  | ||||||
| 	case RelayModeModerations: |  | ||||||
| 		if textRequest.Input == "" { |  | ||||||
| 			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	case RelayModeEdits: |  | ||||||
| 		if textRequest.Instruction == "" { |  | ||||||
| 			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	// map model name |  | ||||||
| 	modelMapping := c.GetString("model_mapping") |  | ||||||
| 	isModelMapped := false |  | ||||||
| 	if modelMapping != "" && modelMapping != "{}" { |  | ||||||
| 		modelMap := make(map[string]string) |  | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		if modelMap[textRequest.Model] != "" { |  | ||||||
| 			textRequest.Model = modelMap[textRequest.Model] |  | ||||||
| 			isModelMapped = true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	apiType := APITypeOpenAI |  | ||||||
| 	switch channelType { |  | ||||||
| 	case common.ChannelTypeAnthropic: |  | ||||||
| 		apiType = APITypeClaude |  | ||||||
| 	case common.ChannelTypeBaidu: |  | ||||||
| 		apiType = APITypeBaidu |  | ||||||
| 	case common.ChannelTypePaLM: |  | ||||||
| 		apiType = APITypePaLM |  | ||||||
| 	case common.ChannelTypeZhipu: |  | ||||||
| 		apiType = APITypeZhipu |  | ||||||
| 	case common.ChannelTypeAli: |  | ||||||
| 		apiType = APITypeAli |  | ||||||
| 	case common.ChannelTypeXunfei: |  | ||||||
| 		apiType = APITypeXunfei |  | ||||||
| 	case common.ChannelTypeAIProxyLibrary: |  | ||||||
| 		apiType = APITypeAIProxyLibrary |  | ||||||
| 	case common.ChannelTypeTencent: |  | ||||||
| 		apiType = APITypeTencent |  | ||||||
| 	case common.ChannelTypeGemini: |  | ||||||
| 		apiType = APITypeGemini |  | ||||||
| 	} |  | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] |  | ||||||
| 	requestURL := c.Request.URL.String() |  | ||||||
| 	if c.GetString("base_url") != "" { |  | ||||||
| 		baseURL = c.GetString("base_url") |  | ||||||
| 	} |  | ||||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) |  | ||||||
| 	switch apiType { |  | ||||||
| 	case APITypeOpenAI: |  | ||||||
| 		if channelType == common.ChannelTypeAzure { |  | ||||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api |  | ||||||
| 			apiVersion := GetAPIVersion(c) |  | ||||||
| 			requestURL := strings.Split(requestURL, "?")[0] |  | ||||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) |  | ||||||
| 			baseURL = c.GetString("base_url") |  | ||||||
| 			task := strings.TrimPrefix(requestURL, "/v1/") |  | ||||||
| 			model_ := textRequest.Model |  | ||||||
| 			model_ = strings.Replace(model_, ".", "", -1) |  | ||||||
| 			// https://github.com/songquanpeng/one-api/issues/67 |  | ||||||
| 			model_ = strings.TrimSuffix(model_, "-0301") |  | ||||||
| 			model_ = strings.TrimSuffix(model_, "-0314") |  | ||||||
| 			model_ = strings.TrimSuffix(model_, "-0613") |  | ||||||
|  |  | ||||||
| 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) |  | ||||||
| 			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) |  | ||||||
| 		} |  | ||||||
| 	case APITypeClaude: |  | ||||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" |  | ||||||
| 		if baseURL != "" { |  | ||||||
| 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) |  | ||||||
| 		} |  | ||||||
| 	case APITypeBaidu: |  | ||||||
| 		switch textRequest.Model { |  | ||||||
| 		case "ERNIE-Bot": |  | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" |  | ||||||
| 		case "ERNIE-Bot-turbo": |  | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" |  | ||||||
| 		case "ERNIE-Bot-4": |  | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" |  | ||||||
| 		case "BLOOMZ-7B": |  | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" |  | ||||||
| 		case "Embedding-V1": |  | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" |  | ||||||
| 		} |  | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") |  | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") |  | ||||||
| 		var err error |  | ||||||
| 		if apiKey, err = getBaiduAccessToken(apiKey); err != nil { |  | ||||||
| 			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		fullRequestURL += "?access_token=" + apiKey |  | ||||||
| 	case APITypePaLM: |  | ||||||
| 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" |  | ||||||
| 		if baseURL != "" { |  | ||||||
| 			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) |  | ||||||
| 		} |  | ||||||
| 	case APITypeGemini: |  | ||||||
| 		requestBaseURL := "https://generativelanguage.googleapis.com" |  | ||||||
| 		if baseURL != "" { |  | ||||||
| 			requestBaseURL = baseURL |  | ||||||
| 		} |  | ||||||
| 		version := "v1" |  | ||||||
| 		if c.GetString("api_version") != "" { |  | ||||||
| 			version = c.GetString("api_version") |  | ||||||
| 		} |  | ||||||
| 		action := "generateContent" |  | ||||||
| 		if textRequest.Stream { |  | ||||||
| 			action = "streamGenerateContent" |  | ||||||
| 		} |  | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) |  | ||||||
| 	case APITypeZhipu: |  | ||||||
| 		method := "invoke" |  | ||||||
| 		if textRequest.Stream { |  | ||||||
| 			method = "sse-invoke" |  | ||||||
| 		} |  | ||||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) |  | ||||||
| 	case APITypeAli: |  | ||||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" |  | ||||||
| 		if relayMode == RelayModeEmbeddings { |  | ||||||
| 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" |  | ||||||
| 		} |  | ||||||
| 	case APITypeTencent: |  | ||||||
| 		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" |  | ||||||
| 	case APITypeAIProxyLibrary: |  | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) |  | ||||||
| 	} |  | ||||||
| 	var promptTokens int |  | ||||||
| 	var completionTokens int |  | ||||||
| 	switch relayMode { |  | ||||||
| 	case RelayModeChatCompletions: |  | ||||||
| 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) |  | ||||||
| 	case RelayModeCompletions: |  | ||||||
| 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) |  | ||||||
| 	case RelayModeModerations: |  | ||||||
| 		promptTokens = countTokenInput(textRequest.Input, textRequest.Model) |  | ||||||
| 	} |  | ||||||
| 	preConsumedTokens := common.PreConsumedQuota |  | ||||||
| 	if textRequest.MaxTokens != 0 { |  | ||||||
| 		preConsumedTokens = promptTokens + textRequest.MaxTokens |  | ||||||
| 	} |  | ||||||
| 	modelRatio := common.GetModelRatio(textRequest.Model) |  | ||||||
| 	groupRatio := common.GetGroupRatio(group) |  | ||||||
| 	ratio := modelRatio * groupRatio |  | ||||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) |  | ||||||
| 	userQuota, err := model.CacheGetUserQuota(userId) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	if userQuota-preConsumedQuota < 0 { |  | ||||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) |  | ||||||
| 	} |  | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| 	if userQuota > 100*preConsumedQuota { |  | ||||||
| 		// in this case, we do not pre-consume quota |  | ||||||
| 		// because the user has enough quota |  | ||||||
| 		preConsumedQuota = 0 |  | ||||||
| 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) |  | ||||||
| 	} |  | ||||||
| 	if preConsumedQuota > 0 { |  | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	var requestBody io.Reader |  | ||||||
| 	if isModelMapped { |  | ||||||
| 		jsonStr, err := json.Marshal(textRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	} else { |  | ||||||
| 		requestBody = c.Request.Body |  | ||||||
| 	} |  | ||||||
| 	switch apiType { |  | ||||||
| 	case APITypeClaude: |  | ||||||
| 		claudeRequest := requestOpenAI2Claude(textRequest) |  | ||||||
| 		jsonStr, err := json.Marshal(claudeRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	case APITypeBaidu: |  | ||||||
| 		var jsonData []byte |  | ||||||
| 		var err error |  | ||||||
| 		switch relayMode { |  | ||||||
| 		case RelayModeEmbeddings: |  | ||||||
| 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) |  | ||||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) |  | ||||||
| 		default: |  | ||||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) |  | ||||||
| 			jsonData, err = json.Marshal(baiduRequest) |  | ||||||
| 		} |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonData) |  | ||||||
| 	case APITypePaLM: |  | ||||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) |  | ||||||
| 		jsonStr, err := json.Marshal(palmRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	case APITypeGemini: |  | ||||||
| 		geminiChatRequest := requestOpenAI2Gemini(textRequest) |  | ||||||
| 		jsonStr, err := json.Marshal(geminiChatRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	case APITypeZhipu: |  | ||||||
| 		zhipuRequest := requestOpenAI2Zhipu(textRequest) |  | ||||||
| 		jsonStr, err := json.Marshal(zhipuRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	case APITypeAli: |  | ||||||
| 		var jsonStr []byte |  | ||||||
| 		var err error |  | ||||||
| 		switch relayMode { |  | ||||||
| 		case RelayModeEmbeddings: |  | ||||||
| 			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) |  | ||||||
| 			jsonStr, err = json.Marshal(aliEmbeddingRequest) |  | ||||||
| 		default: |  | ||||||
| 			aliRequest := requestOpenAI2Ali(textRequest) |  | ||||||
| 			jsonStr, err = json.Marshal(aliRequest) |  | ||||||
| 		} |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	case APITypeTencent: |  | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") |  | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") |  | ||||||
| 		appId, secretId, secretKey, err := parseTencentConfig(apiKey) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		tencentRequest := requestOpenAI2Tencent(textRequest) |  | ||||||
| 		tencentRequest.AppId = appId |  | ||||||
| 		tencentRequest.SecretId = secretId |  | ||||||
| 		jsonStr, err := json.Marshal(tencentRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		sign := getTencentSign(*tencentRequest, secretKey) |  | ||||||
| 		c.Request.Header.Set("Authorization", sign) |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	case APITypeAIProxyLibrary: |  | ||||||
| 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) |  | ||||||
| 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") |  | ||||||
| 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var req *http.Request |  | ||||||
| 	var resp *http.Response |  | ||||||
| 	isStream := textRequest.Stream |  | ||||||
|  |  | ||||||
| 	if apiType != APITypeXunfei { // cause xunfei use websocket |  | ||||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") |  | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") |  | ||||||
| 		switch apiType { |  | ||||||
| 		case APITypeOpenAI: |  | ||||||
| 			if channelType == common.ChannelTypeAzure { |  | ||||||
| 				req.Header.Set("api-key", apiKey) |  | ||||||
| 			} else { |  | ||||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) |  | ||||||
| 				if channelType == common.ChannelTypeOpenRouter { |  | ||||||
| 					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") |  | ||||||
| 					req.Header.Set("X-Title", "One API") |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		case APITypeClaude: |  | ||||||
| 			req.Header.Set("x-api-key", apiKey) |  | ||||||
| 			anthropicVersion := c.Request.Header.Get("anthropic-version") |  | ||||||
| 			if anthropicVersion == "" { |  | ||||||
| 				anthropicVersion = "2023-06-01" |  | ||||||
| 			} |  | ||||||
| 			req.Header.Set("anthropic-version", anthropicVersion) |  | ||||||
| 		case APITypeZhipu: |  | ||||||
| 			token := getZhipuToken(apiKey) |  | ||||||
| 			req.Header.Set("Authorization", token) |  | ||||||
| 		case APITypeAli: |  | ||||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) |  | ||||||
| 			if textRequest.Stream { |  | ||||||
| 				req.Header.Set("X-DashScope-SSE", "enable") |  | ||||||
| 			} |  | ||||||
| 			if c.GetString("plugin") != "" { |  | ||||||
| 				req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) |  | ||||||
| 			} |  | ||||||
| 		case APITypeTencent: |  | ||||||
| 			req.Header.Set("Authorization", apiKey) |  | ||||||
| 		case APITypePaLM: |  | ||||||
| 			req.Header.Set("x-goog-api-key", apiKey) |  | ||||||
| 		case APITypeGemini: |  | ||||||
| 			req.Header.Set("x-goog-api-key", apiKey) |  | ||||||
| 		default: |  | ||||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) |  | ||||||
| 		} |  | ||||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) |  | ||||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) |  | ||||||
| 		if isStream && c.Request.Header.Get("Accept") == "" { |  | ||||||
| 			req.Header.Set("Accept", "text/event-stream") |  | ||||||
| 		} |  | ||||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) |  | ||||||
| 		resp, err = httpClient.Do(req) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		err = req.Body.Close() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		err = c.Request.Body.Close() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) |  | ||||||
| 		} |  | ||||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") |  | ||||||
|  |  | ||||||
| 		if resp.StatusCode != http.StatusOK { |  | ||||||
| 			if preConsumedQuota != 0 { |  | ||||||
| 				go func(ctx context.Context) { |  | ||||||
| 					// return pre-consumed quota |  | ||||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) |  | ||||||
| 					if err != nil { |  | ||||||
| 						common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) |  | ||||||
| 					} |  | ||||||
| 				}(c.Request.Context()) |  | ||||||
| 			} |  | ||||||
| 			return relayErrorHandler(resp) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var textResponse TextResponse |  | ||||||
| 	tokenName := c.GetString("token_name") |  | ||||||
|  |  | ||||||
| 	defer func(ctx context.Context) { |  | ||||||
| 		// c.Writer.Flush() |  | ||||||
| 		go func() { |  | ||||||
| 			quota := 0 |  | ||||||
| 			completionRatio := common.GetCompletionRatio(textRequest.Model) |  | ||||||
| 			promptTokens = textResponse.Usage.PromptTokens |  | ||||||
| 			completionTokens = textResponse.Usage.CompletionTokens |  | ||||||
| 			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) |  | ||||||
| 			if ratio != 0 && quota <= 0 { |  | ||||||
| 				quota = 1 |  | ||||||
| 			} |  | ||||||
| 			totalTokens := promptTokens + completionTokens |  | ||||||
| 			if totalTokens == 0 { |  | ||||||
| 				// in this case, must be some error happened |  | ||||||
| 				// we cannot just return, because we may have to return the pre-consumed quota |  | ||||||
| 				quota = 0 |  | ||||||
| 			} |  | ||||||
| 			quotaDelta := quota - preConsumedQuota |  | ||||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.LogError(ctx, "error consuming token remain quota: "+err.Error()) |  | ||||||
| 			} |  | ||||||
| 			err = model.CacheUpdateUserQuota(userId) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.LogError(ctx, "error update user quota cache: "+err.Error()) |  | ||||||
| 			} |  | ||||||
| 			if quota != 0 { |  | ||||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) |  | ||||||
| 				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) |  | ||||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) |  | ||||||
| 				model.UpdateChannelUsedQuota(channelId, quota) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 		}() |  | ||||||
| 	}(c.Request.Context()) |  | ||||||
| 	switch apiType { |  | ||||||
| 	case APITypeOpenAI: |  | ||||||
| 		if isStream { |  | ||||||
| 			err, responseText := openaiStreamHandler(c, resp, relayMode) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens |  | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypeClaude: |  | ||||||
| 		if isStream { |  | ||||||
| 			err, responseText := claudeStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens |  | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypeBaidu: |  | ||||||
| 		if isStream { |  | ||||||
| 			err, usage := baiduStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			var err *OpenAIErrorWithStatusCode |  | ||||||
| 			var usage *Usage |  | ||||||
| 			switch relayMode { |  | ||||||
| 			case RelayModeEmbeddings: |  | ||||||
| 				err, usage = baiduEmbeddingHandler(c, resp) |  | ||||||
| 			default: |  | ||||||
| 				err, usage = baiduHandler(c, resp) |  | ||||||
| 			} |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypePaLM: |  | ||||||
| 		if textRequest.Stream { // PaLM2 API does not support stream |  | ||||||
| 			err, responseText := palmStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens |  | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypeGemini: |  | ||||||
| 		if textRequest.Stream { |  | ||||||
| 			err, responseText := geminiChatStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens |  | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypeZhipu: |  | ||||||
| 		if isStream { |  | ||||||
| 			err, usage := zhipuStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			// zhipu's API does not return prompt tokens & completion tokens |  | ||||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			err, usage := zhipuHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			// zhipu's API does not return prompt tokens & completion tokens |  | ||||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypeAli: |  | ||||||
| 		if isStream { |  | ||||||
| 			err, usage := aliStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			var err *OpenAIErrorWithStatusCode |  | ||||||
| 			var usage *Usage |  | ||||||
| 			switch relayMode { |  | ||||||
| 			case RelayModeEmbeddings: |  | ||||||
| 				err, usage = aliEmbeddingHandler(c, resp) |  | ||||||
| 			default: |  | ||||||
| 				err, usage = aliHandler(c, resp) |  | ||||||
| 			} |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypeXunfei: |  | ||||||
| 		auth := c.Request.Header.Get("Authorization") |  | ||||||
| 		auth = strings.TrimPrefix(auth, "Bearer ") |  | ||||||
| 		splits := strings.Split(auth, "|") |  | ||||||
| 		if len(splits) != 3 { |  | ||||||
| 			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) |  | ||||||
| 		} |  | ||||||
| 		var err *OpenAIErrorWithStatusCode |  | ||||||
| 		var usage *Usage |  | ||||||
| 		if isStream { |  | ||||||
| 			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) |  | ||||||
| 		} else { |  | ||||||
| 			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) |  | ||||||
| 		} |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		if usage != nil { |  | ||||||
| 			textResponse.Usage = *usage |  | ||||||
| 		} |  | ||||||
| 		return nil |  | ||||||
| 	case APITypeAIProxyLibrary: |  | ||||||
| 		if isStream { |  | ||||||
| 			err, usage := aiProxyLibraryStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			err, usage := aiProxyLibraryHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	case APITypeTencent: |  | ||||||
| 		if isStream { |  | ||||||
| 			err, responseText := tencentStreamHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			textResponse.Usage.PromptTokens = promptTokens |  | ||||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) |  | ||||||
| 			return nil |  | ||||||
| 		} else { |  | ||||||
| 			err, usage := tencentHandler(c, resp) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 			if usage != nil { |  | ||||||
| 				textResponse.Usage = *usage |  | ||||||
| 			} |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 	default: |  | ||||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1,312 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"crypto/hmac" |  | ||||||
| 	"crypto/sha256" |  | ||||||
| 	"encoding/base64" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"fmt" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/gorilla/websocket" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://console.xfyun.cn/services/cbm |  | ||||||
| // https://www.xfyun.cn/doc/spark/Web.html |  | ||||||
|  |  | ||||||
| type XunfeiMessage struct { |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type XunfeiChatRequest struct { |  | ||||||
| 	Header struct { |  | ||||||
| 		AppId string `json:"app_id"` |  | ||||||
| 	} `json:"header"` |  | ||||||
| 	Parameter struct { |  | ||||||
| 		Chat struct { |  | ||||||
| 			Domain      string  `json:"domain,omitempty"` |  | ||||||
| 			Temperature float64 `json:"temperature,omitempty"` |  | ||||||
| 			TopK        int     `json:"top_k,omitempty"` |  | ||||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` |  | ||||||
| 			Auditing    bool    `json:"auditing,omitempty"` |  | ||||||
| 		} `json:"chat"` |  | ||||||
| 	} `json:"parameter"` |  | ||||||
| 	Payload struct { |  | ||||||
| 		Message struct { |  | ||||||
| 			Text []XunfeiMessage `json:"text"` |  | ||||||
| 		} `json:"message"` |  | ||||||
| 	} `json:"payload"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type XunfeiChatResponseTextItem struct { |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| 	Index   int    `json:"index"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type XunfeiChatResponse struct { |  | ||||||
| 	Header struct { |  | ||||||
| 		Code    int    `json:"code"` |  | ||||||
| 		Message string `json:"message"` |  | ||||||
| 		Sid     string `json:"sid"` |  | ||||||
| 		Status  int    `json:"status"` |  | ||||||
| 	} `json:"header"` |  | ||||||
| 	Payload struct { |  | ||||||
| 		Choices struct { |  | ||||||
| 			Status int                          `json:"status"` |  | ||||||
| 			Seq    int                          `json:"seq"` |  | ||||||
| 			Text   []XunfeiChatResponseTextItem `json:"text"` |  | ||||||
| 		} `json:"choices"` |  | ||||||
| 		Usage struct { |  | ||||||
| 			//Text struct { |  | ||||||
| 			//	QuestionTokens   string `json:"question_tokens"` |  | ||||||
| 			//	PromptTokens     string `json:"prompt_tokens"` |  | ||||||
| 			//	CompletionTokens string `json:"completion_tokens"` |  | ||||||
| 			//	TotalTokens      string `json:"total_tokens"` |  | ||||||
| 			//} `json:"text"` |  | ||||||
| 			Text Usage `json:"text"` |  | ||||||
| 		} `json:"usage"` |  | ||||||
| 	} `json:"payload"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { |  | ||||||
| 	messages := make([]XunfeiMessage, 0, len(request.Messages)) |  | ||||||
| 	for _, message := range request.Messages { |  | ||||||
| 		if message.Role == "system" { |  | ||||||
| 			messages = append(messages, XunfeiMessage{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, XunfeiMessage{ |  | ||||||
| 				Role:    "assistant", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { |  | ||||||
| 			messages = append(messages, XunfeiMessage{ |  | ||||||
| 				Role:    message.Role, |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	xunfeiRequest := XunfeiChatRequest{} |  | ||||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId |  | ||||||
| 	xunfeiRequest.Parameter.Chat.Domain = domain |  | ||||||
| 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature |  | ||||||
| 	xunfeiRequest.Parameter.Chat.TopK = request.N |  | ||||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens |  | ||||||
| 	xunfeiRequest.Payload.Message.Text = messages |  | ||||||
| 	return &xunfeiRequest |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { |  | ||||||
| 	if len(response.Payload.Choices.Text) == 0 { |  | ||||||
| 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ |  | ||||||
| 			{ |  | ||||||
| 				Content: "", |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	choice := OpenAITextResponseChoice{ |  | ||||||
| 		Index: 0, |  | ||||||
| 		Message: Message{ |  | ||||||
| 			Role:    "assistant", |  | ||||||
| 			Content: response.Payload.Choices.Text[0].Content, |  | ||||||
| 		}, |  | ||||||
| 		FinishReason: stopFinishReason, |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Choices: []OpenAITextResponseChoice{choice}, |  | ||||||
| 		Usage:   response.Payload.Usage.Text, |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { |  | ||||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { |  | ||||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ |  | ||||||
| 			{ |  | ||||||
| 				Content: "", |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content |  | ||||||
| 	if xunfeiResponse.Payload.Choices.Status == 2 { |  | ||||||
| 		choice.FinishReason = &stopFinishReason |  | ||||||
| 	} |  | ||||||
| 	response := ChatCompletionsStreamResponse{ |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Model:   "SparkDesk", |  | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { |  | ||||||
| 	HmacWithShaToBase64 := func(algorithm, data, key string) string { |  | ||||||
| 		mac := hmac.New(sha256.New, []byte(key)) |  | ||||||
| 		mac.Write([]byte(data)) |  | ||||||
| 		encodeData := mac.Sum(nil) |  | ||||||
| 		return base64.StdEncoding.EncodeToString(encodeData) |  | ||||||
| 	} |  | ||||||
| 	ul, err := url.Parse(hostUrl) |  | ||||||
| 	if err != nil { |  | ||||||
| 		fmt.Println(err) |  | ||||||
| 	} |  | ||||||
| 	date := time.Now().UTC().Format(time.RFC1123) |  | ||||||
| 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} |  | ||||||
| 	sign := strings.Join(signString, "\n") |  | ||||||
| 	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) |  | ||||||
| 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, |  | ||||||
| 		"hmac-sha256", "host date request-line", sha) |  | ||||||
| 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) |  | ||||||
| 	v := url.Values{} |  | ||||||
| 	v.Add("host", ul.Host) |  | ||||||
| 	v.Add("date", date) |  | ||||||
| 	v.Add("authorization", authorization) |  | ||||||
| 	callUrl := hostUrl + "?" + v.Encode() |  | ||||||
| 	return callUrl |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) |  | ||||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	var usage Usage |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case xunfeiResponse := <-dataChan: |  | ||||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens |  | ||||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens |  | ||||||
| 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens |  | ||||||
| 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) |  | ||||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	var usage Usage |  | ||||||
| 	var content string |  | ||||||
| 	var xunfeiResponse XunfeiChatResponse |  | ||||||
| 	stop := false |  | ||||||
| 	for !stop { |  | ||||||
| 		select { |  | ||||||
| 		case xunfeiResponse = <-dataChan: |  | ||||||
| 			if len(xunfeiResponse.Payload.Choices.Text) == 0 { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			content += xunfeiResponse.Payload.Choices.Text[0].Content |  | ||||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens |  | ||||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens |  | ||||||
| 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens |  | ||||||
| 		case stop = <-stopChan: |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { |  | ||||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ |  | ||||||
| 			{ |  | ||||||
| 				Content: "", |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	xunfeiResponse.Payload.Choices.Text[0].Content = content |  | ||||||
|  |  | ||||||
| 	response := responseXunfei2OpenAI(&xunfeiResponse) |  | ||||||
| 	jsonResponse, err := json.Marshal(response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	_, _ = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { |  | ||||||
| 	d := websocket.Dialer{ |  | ||||||
| 		HandshakeTimeout: 5 * time.Second, |  | ||||||
| 	} |  | ||||||
| 	conn, resp, err := d.Dial(authUrl, nil) |  | ||||||
| 	if err != nil || resp.StatusCode != 101 { |  | ||||||
| 		return nil, nil, err |  | ||||||
| 	} |  | ||||||
| 	data := requestOpenAI2Xunfei(textRequest, appId, domain) |  | ||||||
| 	err = conn.WriteJSON(data) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	dataChan := make(chan XunfeiChatResponse) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for { |  | ||||||
| 			_, msg, err := conn.ReadMessage() |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error reading stream response: " + err.Error()) |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 			var response XunfeiChatResponse |  | ||||||
| 			err = json.Unmarshal(msg, &response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 			dataChan <- response |  | ||||||
| 			if response.Payload.Choices.Status == 2 { |  | ||||||
| 				err := conn.Close() |  | ||||||
| 				if err != nil { |  | ||||||
| 					common.SysError("error closing websocket connection: " + err.Error()) |  | ||||||
| 				} |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	return dataChan, stopChan, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { |  | ||||||
| 	query := c.Request.URL.Query() |  | ||||||
| 	apiVersion := query.Get("api-version") |  | ||||||
| 	if apiVersion == "" { |  | ||||||
| 		apiVersion = c.GetString("api_version") |  | ||||||
| 	} |  | ||||||
| 	if apiVersion == "" { |  | ||||||
| 		apiVersion = "v1.1" |  | ||||||
| 		common.SysLog("api_version not found, use default: " + apiVersion) |  | ||||||
| 	} |  | ||||||
| 	domain := "general" |  | ||||||
| 	if apiVersion != "v1.1" { |  | ||||||
| 		domain += strings.Split(apiVersion, ".")[0] |  | ||||||
| 	} |  | ||||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) |  | ||||||
| 	return domain, authUrl |  | ||||||
| } |  | ||||||
| @@ -1,302 +0,0 @@ | |||||||
| package controller |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/golang-jwt/jwt" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"strings" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // https://open.bigmodel.cn/doc/api#chatglm_std |  | ||||||
| // chatglm_std, chatglm_lite |  | ||||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke |  | ||||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke |  | ||||||
|  |  | ||||||
| type ZhipuMessage struct { |  | ||||||
| 	Role    string `json:"role"` |  | ||||||
| 	Content string `json:"content"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ZhipuRequest struct { |  | ||||||
| 	Prompt      []ZhipuMessage `json:"prompt"` |  | ||||||
| 	Temperature float64        `json:"temperature,omitempty"` |  | ||||||
| 	TopP        float64        `json:"top_p,omitempty"` |  | ||||||
| 	RequestId   string         `json:"request_id,omitempty"` |  | ||||||
| 	Incremental bool           `json:"incremental,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ZhipuResponseData struct { |  | ||||||
| 	TaskId     string         `json:"task_id"` |  | ||||||
| 	RequestId  string         `json:"request_id"` |  | ||||||
| 	TaskStatus string         `json:"task_status"` |  | ||||||
| 	Choices    []ZhipuMessage `json:"choices"` |  | ||||||
| 	Usage      `json:"usage"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ZhipuResponse struct { |  | ||||||
| 	Code    int               `json:"code"` |  | ||||||
| 	Msg     string            `json:"msg"` |  | ||||||
| 	Success bool              `json:"success"` |  | ||||||
| 	Data    ZhipuResponseData `json:"data"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ZhipuStreamMetaResponse struct { |  | ||||||
| 	RequestId  string `json:"request_id"` |  | ||||||
| 	TaskId     string `json:"task_id"` |  | ||||||
| 	TaskStatus string `json:"task_status"` |  | ||||||
| 	Usage      `json:"usage"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type zhipuTokenData struct { |  | ||||||
| 	Token      string |  | ||||||
| 	ExpiryTime time.Time |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var zhipuTokens sync.Map |  | ||||||
| var expSeconds int64 = 24 * 3600 |  | ||||||
|  |  | ||||||
| func getZhipuToken(apikey string) string { |  | ||||||
| 	data, ok := zhipuTokens.Load(apikey) |  | ||||||
| 	if ok { |  | ||||||
| 		tokenData := data.(zhipuTokenData) |  | ||||||
| 		if time.Now().Before(tokenData.ExpiryTime) { |  | ||||||
| 			return tokenData.Token |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	split := strings.Split(apikey, ".") |  | ||||||
| 	if len(split) != 2 { |  | ||||||
| 		common.SysError("invalid zhipu key: " + apikey) |  | ||||||
| 		return "" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	id := split[0] |  | ||||||
| 	secret := split[1] |  | ||||||
|  |  | ||||||
| 	expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 |  | ||||||
| 	expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) |  | ||||||
|  |  | ||||||
| 	timestamp := time.Now().UnixNano() / 1e6 |  | ||||||
|  |  | ||||||
| 	payload := jwt.MapClaims{ |  | ||||||
| 		"api_key":   id, |  | ||||||
| 		"exp":       expMillis, |  | ||||||
| 		"timestamp": timestamp, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) |  | ||||||
|  |  | ||||||
| 	token.Header["alg"] = "HS256" |  | ||||||
| 	token.Header["sign_type"] = "SIGN" |  | ||||||
|  |  | ||||||
| 	tokenString, err := token.SignedString([]byte(secret)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	zhipuTokens.Store(apikey, zhipuTokenData{ |  | ||||||
| 		Token:      tokenString, |  | ||||||
| 		ExpiryTime: expiryTime, |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	return tokenString |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { |  | ||||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) |  | ||||||
| 	for _, message := range request.Messages { |  | ||||||
| 		if message.Role == "system" { |  | ||||||
| 			messages = append(messages, ZhipuMessage{ |  | ||||||
| 				Role:    "system", |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 			messages = append(messages, ZhipuMessage{ |  | ||||||
| 				Role:    "user", |  | ||||||
| 				Content: "Okay", |  | ||||||
| 			}) |  | ||||||
| 		} else { |  | ||||||
| 			messages = append(messages, ZhipuMessage{ |  | ||||||
| 				Role:    message.Role, |  | ||||||
| 				Content: message.StringContent(), |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return &ZhipuRequest{ |  | ||||||
| 		Prompt:      messages, |  | ||||||
| 		Temperature: request.Temperature, |  | ||||||
| 		TopP:        request.TopP, |  | ||||||
| 		Incremental: false, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { |  | ||||||
| 	fullTextResponse := OpenAITextResponse{ |  | ||||||
| 		Id:      response.Data.TaskId, |  | ||||||
| 		Object:  "chat.completion", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), |  | ||||||
| 		Usage:   response.Data.Usage, |  | ||||||
| 	} |  | ||||||
| 	for i, choice := range response.Data.Choices { |  | ||||||
| 		openaiChoice := OpenAITextResponseChoice{ |  | ||||||
| 			Index: i, |  | ||||||
| 			Message: Message{ |  | ||||||
| 				Role:    choice.Role, |  | ||||||
| 				Content: strings.Trim(choice.Content, "\""), |  | ||||||
| 			}, |  | ||||||
| 			FinishReason: "", |  | ||||||
| 		} |  | ||||||
| 		if i == len(response.Data.Choices)-1 { |  | ||||||
| 			openaiChoice.FinishReason = "stop" |  | ||||||
| 		} |  | ||||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) |  | ||||||
| 	} |  | ||||||
| 	return &fullTextResponse |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = zhipuResponse |  | ||||||
| 	response := ChatCompletionsStreamResponse{ |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Model:   "chatglm", |  | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| 	return &response |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { |  | ||||||
| 	var choice ChatCompletionsStreamResponseChoice |  | ||||||
| 	choice.Delta.Content = "" |  | ||||||
| 	choice.FinishReason = &stopFinishReason |  | ||||||
| 	response := ChatCompletionsStreamResponse{ |  | ||||||
| 		Id:      zhipuResponse.RequestId, |  | ||||||
| 		Object:  "chat.completion.chunk", |  | ||||||
| 		Created: common.GetTimestamp(), |  | ||||||
| 		Model:   "chatglm", |  | ||||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, |  | ||||||
| 	} |  | ||||||
| 	return &response, &zhipuResponse.Usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var usage *Usage |  | ||||||
| 	scanner := bufio.NewScanner(resp.Body) |  | ||||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { |  | ||||||
| 		if atEOF && len(data) == 0 { |  | ||||||
| 			return 0, nil, nil |  | ||||||
| 		} |  | ||||||
| 		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { |  | ||||||
| 			return i + 2, data[0:i], nil |  | ||||||
| 		} |  | ||||||
| 		if atEOF { |  | ||||||
| 			return len(data), data, nil |  | ||||||
| 		} |  | ||||||
| 		return 0, nil, nil |  | ||||||
| 	}) |  | ||||||
| 	dataChan := make(chan string) |  | ||||||
| 	metaChan := make(chan string) |  | ||||||
| 	stopChan := make(chan bool) |  | ||||||
| 	go func() { |  | ||||||
| 		for scanner.Scan() { |  | ||||||
| 			data := scanner.Text() |  | ||||||
| 			lines := strings.Split(data, "\n") |  | ||||||
| 			for i, line := range lines { |  | ||||||
| 				if len(line) < 5 { |  | ||||||
| 					continue |  | ||||||
| 				} |  | ||||||
| 				if line[:5] == "data:" { |  | ||||||
| 					dataChan <- line[5:] |  | ||||||
| 					if i != len(lines)-1 { |  | ||||||
| 						dataChan <- "\n" |  | ||||||
| 					} |  | ||||||
| 				} else if line[:5] == "meta:" { |  | ||||||
| 					metaChan <- line[5:] |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		stopChan <- true |  | ||||||
| 	}() |  | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { |  | ||||||
| 		select { |  | ||||||
| 		case data := <-dataChan: |  | ||||||
| 			response := streamResponseZhipu2OpenAI(data) |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case data := <-metaChan: |  | ||||||
| 			var zhipuResponse ZhipuStreamMetaResponse |  | ||||||
| 			err := json.Unmarshal([]byte(data), &zhipuResponse) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			usage = zhipuUsage |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	err := resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	return nil, usage |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { |  | ||||||
| 	var zhipuResponse ZhipuResponse |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	err = json.Unmarshal(responseBody, &zhipuResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	if !zhipuResponse.Success { |  | ||||||
| 		return &OpenAIErrorWithStatusCode{ |  | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: zhipuResponse.Msg, |  | ||||||
| 				Type:    "zhipu_error", |  | ||||||
| 				Param:   "", |  | ||||||
| 				Code:    zhipuResponse.Code, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) |  | ||||||
| 	fullTextResponse.Model = "chatglm" |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } |  | ||||||
| @@ -1,384 +1,134 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| ) | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| type Message struct { | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
| 	Role    string  `json:"role"` | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	Content any     `json:"content"` | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	Name    *string `json:"name,omitempty"` | 	"github.com/songquanpeng/one-api/middleware" | ||||||
| } | 	dbmodel "github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/monitor" | ||||||
| type ImageURL struct { | 	"github.com/songquanpeng/one-api/relay/controller" | ||||||
| 	Url    string `json:"url,omitempty"` | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
| 	Detail string `json:"detail,omitempty"` | 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextContent struct { |  | ||||||
| 	Type string `json:"type,omitempty"` |  | ||||||
| 	Text string `json:"text,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ImageContent struct { |  | ||||||
| 	Type     string    `json:"type,omitempty"` |  | ||||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	ContentTypeText     = "text" |  | ||||||
| 	ContentTypeImageURL = "image_url" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type OpenAIMessageContent struct { |  | ||||||
| 	Type     string    `json:"type,omitempty"` |  | ||||||
| 	Text     string    `json:"text"` |  | ||||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m Message) IsStringContent() bool { |  | ||||||
| 	_, ok := m.Content.(string) |  | ||||||
| 	return ok |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m Message) StringContent() string { |  | ||||||
| 	content, ok := m.Content.(string) |  | ||||||
| 	if ok { |  | ||||||
| 		return content |  | ||||||
| 	} |  | ||||||
| 	contentList, ok := m.Content.([]any) |  | ||||||
| 	if ok { |  | ||||||
| 		var contentStr string |  | ||||||
| 		for _, contentItem := range contentList { |  | ||||||
| 			contentMap, ok := contentItem.(map[string]any) |  | ||||||
| 			if !ok { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if contentMap["type"] == ContentTypeText { |  | ||||||
| 				if subStr, ok := contentMap["text"].(string); ok { |  | ||||||
| 					contentStr += subStr |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return contentStr |  | ||||||
| 	} |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (m Message) ParseContent() []OpenAIMessageContent { |  | ||||||
| 	var contentList []OpenAIMessageContent |  | ||||||
| 	content, ok := m.Content.(string) |  | ||||||
| 	if ok { |  | ||||||
| 		contentList = append(contentList, OpenAIMessageContent{ |  | ||||||
| 			Type: ContentTypeText, |  | ||||||
| 			Text: content, |  | ||||||
| 		}) |  | ||||||
| 		return contentList |  | ||||||
| 	} |  | ||||||
| 	anyList, ok := m.Content.([]any) |  | ||||||
| 	if ok { |  | ||||||
| 		for _, contentItem := range anyList { |  | ||||||
| 			contentMap, ok := contentItem.(map[string]any) |  | ||||||
| 			if !ok { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			switch contentMap["type"] { |  | ||||||
| 			case ContentTypeText: |  | ||||||
| 				if subStr, ok := contentMap["text"].(string); ok { |  | ||||||
| 					contentList = append(contentList, OpenAIMessageContent{ |  | ||||||
| 						Type: ContentTypeText, |  | ||||||
| 						Text: subStr, |  | ||||||
| 					}) |  | ||||||
| 				} |  | ||||||
| 			case ContentTypeImageURL: |  | ||||||
| 				if subObj, ok := contentMap["image_url"].(map[string]any); ok { |  | ||||||
| 					contentList = append(contentList, OpenAIMessageContent{ |  | ||||||
| 						Type: ContentTypeImageURL, |  | ||||||
| 						ImageURL: &ImageURL{ |  | ||||||
| 							Url: subObj["url"].(string), |  | ||||||
| 						}, |  | ||||||
| 					}) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return contentList |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	RelayModeUnknown = iota |  | ||||||
| 	RelayModeChatCompletions |  | ||||||
| 	RelayModeCompletions |  | ||||||
| 	RelayModeEmbeddings |  | ||||||
| 	RelayModeModerations |  | ||||||
| 	RelayModeImagesGenerations |  | ||||||
| 	RelayModeEdits |  | ||||||
| 	RelayModeAudioSpeech |  | ||||||
| 	RelayModeAudioTranscription |  | ||||||
| 	RelayModeAudioTranslation |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // https://platform.openai.com/docs/api-reference/chat | // https://platform.openai.com/docs/api-reference/chat | ||||||
|  |  | ||||||
| type ResponseFormat struct { | func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { | ||||||
| 	Type string `json:"type,omitempty"` | 	var err *model.ErrorWithStatusCode | ||||||
| } | 	switch relayMode { | ||||||
|  | 	case relaymode.ImagesGenerations: | ||||||
| type GeneralOpenAIRequest struct { | 		err = controller.RelayImageHelper(c, relayMode) | ||||||
| 	Model            string          `json:"model,omitempty"` | 	case relaymode.AudioSpeech: | ||||||
| 	Messages         []Message       `json:"messages,omitempty"` | 		fallthrough | ||||||
| 	Prompt           any             `json:"prompt,omitempty"` | 	case relaymode.AudioTranslation: | ||||||
| 	Stream           bool            `json:"stream,omitempty"` | 		fallthrough | ||||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | 	case relaymode.AudioTranscription: | ||||||
| 	Temperature      float64         `json:"temperature,omitempty"` | 		err = controller.RelayAudioHelper(c, relayMode) | ||||||
| 	TopP             float64         `json:"top_p,omitempty"` | 	default: | ||||||
| 	N                int             `json:"n,omitempty"` | 		err = controller.RelayTextHelper(c) | ||||||
| 	Input            any             `json:"input,omitempty"` |  | ||||||
| 	Instruction      string          `json:"instruction,omitempty"` |  | ||||||
| 	Size             string          `json:"size,omitempty"` |  | ||||||
| 	Functions        any             `json:"functions,omitempty"` |  | ||||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` |  | ||||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` |  | ||||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` |  | ||||||
| 	Seed             float64         `json:"seed,omitempty"` |  | ||||||
| 	Tools            any             `json:"tools,omitempty"` |  | ||||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` |  | ||||||
| 	User             string          `json:"user,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r GeneralOpenAIRequest) ParseInput() []string { |  | ||||||
| 	if r.Input == nil { |  | ||||||
| 		return nil |  | ||||||
| 	} | 	} | ||||||
| 	var input []string | 	return err | ||||||
| 	switch r.Input.(type) { |  | ||||||
| 	case string: |  | ||||||
| 		input = []string{r.Input.(string)} |  | ||||||
| 	case []any: |  | ||||||
| 		input = make([]string, 0, len(r.Input.([]any))) |  | ||||||
| 		for _, item := range r.Input.([]any) { |  | ||||||
| 			if str, ok := item.(string); ok { |  | ||||||
| 				input = append(input, str) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return input |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ChatRequest struct { |  | ||||||
| 	Model     string    `json:"model"` |  | ||||||
| 	Messages  []Message `json:"messages"` |  | ||||||
| 	MaxTokens int       `json:"max_tokens"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextRequest struct { |  | ||||||
| 	Model     string    `json:"model"` |  | ||||||
| 	Messages  []Message `json:"messages"` |  | ||||||
| 	Prompt    string    `json:"prompt"` |  | ||||||
| 	MaxTokens int       `json:"max_tokens"` |  | ||||||
| 	//Stream   bool      `json:"stream"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create |  | ||||||
| type ImageRequest struct { |  | ||||||
| 	Model          string `json:"model"` |  | ||||||
| 	Prompt         string `json:"prompt" binding:"required"` |  | ||||||
| 	N              int    `json:"n,omitempty"` |  | ||||||
| 	Size           string `json:"size,omitempty"` |  | ||||||
| 	Quality        string `json:"quality,omitempty"` |  | ||||||
| 	ResponseFormat string `json:"response_format,omitempty"` |  | ||||||
| 	Style          string `json:"style,omitempty"` |  | ||||||
| 	User           string `json:"user,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type WhisperJSONResponse struct { |  | ||||||
| 	Text string `json:"text,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type WhisperVerboseJSONResponse struct { |  | ||||||
| 	Task     string    `json:"task,omitempty"` |  | ||||||
| 	Language string    `json:"language,omitempty"` |  | ||||||
| 	Duration float64   `json:"duration,omitempty"` |  | ||||||
| 	Text     string    `json:"text,omitempty"` |  | ||||||
| 	Segments []Segment `json:"segments,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type Segment struct { |  | ||||||
| 	Id               int     `json:"id"` |  | ||||||
| 	Seek             int     `json:"seek"` |  | ||||||
| 	Start            float64 `json:"start"` |  | ||||||
| 	End              float64 `json:"end"` |  | ||||||
| 	Text             string  `json:"text"` |  | ||||||
| 	Tokens           []int   `json:"tokens"` |  | ||||||
| 	Temperature      float64 `json:"temperature"` |  | ||||||
| 	AvgLogprob       float64 `json:"avg_logprob"` |  | ||||||
| 	CompressionRatio float64 `json:"compression_ratio"` |  | ||||||
| 	NoSpeechProb     float64 `json:"no_speech_prob"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextToSpeechRequest struct { |  | ||||||
| 	Model          string  `json:"model" binding:"required"` |  | ||||||
| 	Input          string  `json:"input" binding:"required"` |  | ||||||
| 	Voice          string  `json:"voice" binding:"required"` |  | ||||||
| 	Speed          float64 `json:"speed"` |  | ||||||
| 	ResponseFormat string  `json:"response_format"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type Usage struct { |  | ||||||
| 	PromptTokens     int `json:"prompt_tokens"` |  | ||||||
| 	CompletionTokens int `json:"completion_tokens"` |  | ||||||
| 	TotalTokens      int `json:"total_tokens"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIError struct { |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| 	Type    string `json:"type"` |  | ||||||
| 	Param   string `json:"param"` |  | ||||||
| 	Code    any    `json:"code"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIErrorWithStatusCode struct { |  | ||||||
| 	OpenAIError |  | ||||||
| 	StatusCode int `json:"status_code"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TextResponse struct { |  | ||||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` |  | ||||||
| 	Usage   `json:"usage"` |  | ||||||
| 	Error   OpenAIError `json:"error"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAITextResponseChoice struct { |  | ||||||
| 	Index        int `json:"index"` |  | ||||||
| 	Message      `json:"message"` |  | ||||||
| 	FinishReason string `json:"finish_reason"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAITextResponse struct { |  | ||||||
| 	Id      string                     `json:"id"` |  | ||||||
| 	Model   string                     `json:"model,omitempty"` |  | ||||||
| 	Object  string                     `json:"object"` |  | ||||||
| 	Created int64                      `json:"created"` |  | ||||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` |  | ||||||
| 	Usage   `json:"usage"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIEmbeddingResponseItem struct { |  | ||||||
| 	Object    string    `json:"object"` |  | ||||||
| 	Index     int       `json:"index"` |  | ||||||
| 	Embedding []float64 `json:"embedding"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type OpenAIEmbeddingResponse struct { |  | ||||||
| 	Object string                        `json:"object"` |  | ||||||
| 	Data   []OpenAIEmbeddingResponseItem `json:"data"` |  | ||||||
| 	Model  string                        `json:"model"` |  | ||||||
| 	Usage  `json:"usage"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ImageResponse struct { |  | ||||||
| 	Created int `json:"created"` |  | ||||||
| 	Data    []struct { |  | ||||||
| 		Url string `json:"url"` |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ChatCompletionsStreamResponseChoice struct { |  | ||||||
| 	Delta struct { |  | ||||||
| 		Content string `json:"content"` |  | ||||||
| 	} `json:"delta"` |  | ||||||
| 	FinishReason *string `json:"finish_reason,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type ChatCompletionsStreamResponse struct { |  | ||||||
| 	Id      string                                `json:"id"` |  | ||||||
| 	Object  string                                `json:"object"` |  | ||||||
| 	Created int64                                 `json:"created"` |  | ||||||
| 	Model   string                                `json:"model"` |  | ||||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type CompletionsStreamResponse struct { |  | ||||||
| 	Choices []struct { |  | ||||||
| 		Text         string `json:"text"` |  | ||||||
| 		FinishReason string `json:"finish_reason"` |  | ||||||
| 	} `json:"choices"` |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func Relay(c *gin.Context) { | func Relay(c *gin.Context) { | ||||||
| 	relayMode := RelayModeUnknown | 	ctx := c.Request.Context() | ||||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { | 	relayMode := relaymode.GetByPath(c.Request.URL.Path) | ||||||
| 		relayMode = RelayModeChatCompletions | 	if config.DebugEnabled { | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { | 		requestBody, _ := common.GetRequestBody(c) | ||||||
| 		relayMode = RelayModeCompletions | 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { |  | ||||||
| 		relayMode = RelayModeEmbeddings |  | ||||||
| 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { |  | ||||||
| 		relayMode = RelayModeEmbeddings |  | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { |  | ||||||
| 		relayMode = RelayModeModerations |  | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { |  | ||||||
| 		relayMode = RelayModeImagesGenerations |  | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { |  | ||||||
| 		relayMode = RelayModeEdits |  | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { |  | ||||||
| 		relayMode = RelayModeAudioSpeech |  | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { |  | ||||||
| 		relayMode = RelayModeAudioTranscription |  | ||||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { |  | ||||||
| 		relayMode = RelayModeAudioTranslation |  | ||||||
| 	} | 	} | ||||||
| 	var err *OpenAIErrorWithStatusCode | 	channelId := c.GetInt(ctxkey.ChannelId) | ||||||
| 	switch relayMode { | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	case RelayModeImagesGenerations: | 	bizErr := relayHelper(c, relayMode) | ||||||
| 		err = relayImageHelper(c, relayMode) | 	if bizErr == nil { | ||||||
| 	case RelayModeAudioSpeech: | 		monitor.Emit(channelId, true) | ||||||
| 		fallthrough | 		return | ||||||
| 	case RelayModeAudioTranslation: |  | ||||||
| 		fallthrough |  | ||||||
| 	case RelayModeAudioTranscription: |  | ||||||
| 		err = relayAudioHelper(c, relayMode) |  | ||||||
| 	default: |  | ||||||
| 		err = relayTextHelper(c, relayMode) |  | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	lastFailedChannelId := channelId | ||||||
| 		requestId := c.GetString(common.RequestIdKey) | 	channelName := c.GetString(ctxkey.ChannelName) | ||||||
| 		retryTimesStr := c.Query("retry") | 	group := c.GetString(ctxkey.Group) | ||||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | 	originalModel := c.GetString(ctxkey.OriginalModel) | ||||||
| 		if retryTimesStr == "" { | 	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||||
| 			retryTimes = common.RetryTimes | 	requestId := c.GetString(helper.RequestIdKey) | ||||||
|  | 	retryTimes := config.RetryTimes | ||||||
|  | 	if !shouldRetry(c, bizErr.StatusCode) { | ||||||
|  | 		logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) | ||||||
|  | 		retryTimes = 0 | ||||||
|  | 	} | ||||||
|  | 	for i := retryTimes; i > 0; i-- { | ||||||
|  | 		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err) | ||||||
|  | 			break | ||||||
| 		} | 		} | ||||||
| 		if retryTimes > 0 { | 		logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) | ||||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | 		if channel.Id == lastFailedChannelId { | ||||||
| 		} else { | 			continue | ||||||
| 			if err.StatusCode == http.StatusTooManyRequests { |  | ||||||
| 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" |  | ||||||
| 			} |  | ||||||
| 			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) |  | ||||||
| 			c.JSON(err.StatusCode, gin.H{ |  | ||||||
| 				"error": err.OpenAIError, |  | ||||||
| 			}) |  | ||||||
| 		} | 		} | ||||||
| 		channelId := c.GetInt("channel_id") | 		middleware.SetupContextForSelectedChannel(c, channel, originalModel) | ||||||
| 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | 		requestBody, err := common.GetRequestBody(c) | ||||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||||
| 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | 		bizErr = relayHelper(c, relayMode) | ||||||
| 			channelId := c.GetInt("channel_id") | 		if bizErr == nil { | ||||||
| 			channelName := c.GetString("channel_name") | 			return | ||||||
| 			disableChannel(channelId, channelName, err.Message) |  | ||||||
| 		} | 		} | ||||||
|  | 		channelId := c.GetInt(ctxkey.ChannelId) | ||||||
|  | 		lastFailedChannelId = channelId | ||||||
|  | 		channelName := c.GetString(ctxkey.ChannelName) | ||||||
|  | 		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||||
|  | 	} | ||||||
|  | 	if bizErr != nil { | ||||||
|  | 		if bizErr.StatusCode == http.StatusTooManyRequests { | ||||||
|  | 			bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" | ||||||
|  | 		} | ||||||
|  | 		bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) | ||||||
|  | 		c.JSON(bizErr.StatusCode, gin.H{ | ||||||
|  | 			"error": bizErr.Error, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func shouldRetry(c *gin.Context, statusCode int) bool { | ||||||
|  | 	if _, ok := c.Get(ctxkey.SpecificChannelId); ok { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusTooManyRequests { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if statusCode/100 == 5 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusBadRequest { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode/100 == 2 { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { | ||||||
|  | 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) | ||||||
|  | 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
|  | 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||||
|  | 		monitor.DisableChannel(channelId, channelName, err.Message) | ||||||
|  | 	} else { | ||||||
|  | 		monitor.Emit(channelId, false) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RelayNotImplemented(c *gin.Context) { | func RelayNotImplemented(c *gin.Context) { | ||||||
| 	err := OpenAIError{ | 	err := model.Error{ | ||||||
| 		Message: "API not implemented", | 		Message: "API not implemented", | ||||||
| 		Type:    "one_api_error", | 		Type:    "one_api_error", | ||||||
| 		Param:   "", | 		Param:   "", | ||||||
| @@ -390,7 +140,7 @@ func RelayNotImplemented(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func RelayNotFound(c *gin.Context) { | func RelayNotFound(c *gin.Context) { | ||||||
| 	err := OpenAIError{ | 	err := model.Error{ | ||||||
| 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | 		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), | ||||||
| 		Type:    "invalid_request_error", | 		Type:    "invalid_request_error", | ||||||
| 		Param:   "", | 		Param:   "", | ||||||
|   | |||||||
| @@ -1,20 +1,28 @@ | |||||||
| package controller | package controller | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/network" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func GetAllTokens(c *gin.Context) { | func GetAllTokens(c *gin.Context) { | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	p, _ := strconv.Atoi(c.Query("p")) | 	p, _ := strconv.Atoi(c.Query("p")) | ||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) |  | ||||||
|  | 	order := c.Query("order") | ||||||
|  | 	tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -31,7 +39,7 @@ func GetAllTokens(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func SearchTokens(c *gin.Context) { | func SearchTokens(c *gin.Context) { | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	keyword := c.Query("keyword") | 	keyword := c.Query("keyword") | ||||||
| 	tokens, err := model.SearchUserTokens(userId, keyword) | 	tokens, err := model.SearchUserTokens(userId, keyword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -51,7 +59,7 @@ func SearchTokens(c *gin.Context) { | |||||||
|  |  | ||||||
| func GetToken(c *gin.Context) { | func GetToken(c *gin.Context) { | ||||||
| 	id, err := strconv.Atoi(c.Param("id")) | 	id, err := strconv.Atoi(c.Param("id")) | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -76,8 +84,8 @@ func GetToken(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetTokenStatus(c *gin.Context) { | func GetTokenStatus(c *gin.Context) { | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt(ctxkey.TokenId) | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	token, err := model.GetTokenByIds(tokenId, userId) | 	token, err := model.GetTokenByIds(tokenId, userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -99,6 +107,19 @@ func GetTokenStatus(c *gin.Context) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func validateToken(c *gin.Context, token model.Token) error { | ||||||
|  | 	if len(token.Name) > 30 { | ||||||
|  | 		return fmt.Errorf("令牌名称过长") | ||||||
|  | 	} | ||||||
|  | 	if token.Subnet != nil && *token.Subnet != "" { | ||||||
|  | 		err := network.IsValidSubnets(*token.Subnet) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("无效的网段:%s", err.Error()) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func AddToken(c *gin.Context) { | func AddToken(c *gin.Context) { | ||||||
| 	token := model.Token{} | 	token := model.Token{} | ||||||
| 	err := c.ShouldBindJSON(&token) | 	err := c.ShouldBindJSON(&token) | ||||||
| @@ -109,22 +130,26 @@ func AddToken(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if len(token.Name) > 30 { | 	err = validateToken(c, token) | ||||||
|  | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "令牌名称过长", | 			"message": fmt.Sprintf("参数错误:%s", err.Error()), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cleanToken := model.Token{ | 	cleanToken := model.Token{ | ||||||
| 		UserId:         c.GetInt("id"), | 		UserId:         c.GetInt(ctxkey.Id), | ||||||
| 		Name:           token.Name, | 		Name:           token.Name, | ||||||
| 		Key:            common.GenerateKey(), | 		Key:            random.GenerateKey(), | ||||||
| 		CreatedTime:    common.GetTimestamp(), | 		CreatedTime:    helper.GetTimestamp(), | ||||||
| 		AccessedTime:   common.GetTimestamp(), | 		AccessedTime:   helper.GetTimestamp(), | ||||||
| 		ExpiredTime:    token.ExpiredTime, | 		ExpiredTime:    token.ExpiredTime, | ||||||
| 		RemainQuota:    token.RemainQuota, | 		RemainQuota:    token.RemainQuota, | ||||||
| 		UnlimitedQuota: token.UnlimitedQuota, | 		UnlimitedQuota: token.UnlimitedQuota, | ||||||
|  | 		Models:         token.Models, | ||||||
|  | 		Subnet:         token.Subnet, | ||||||
| 	} | 	} | ||||||
| 	err = cleanToken.Insert() | 	err = cleanToken.Insert() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -137,13 +162,14 @@ func AddToken(c *gin.Context) { | |||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
|  | 		"data":    cleanToken, | ||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteToken(c *gin.Context) { | func DeleteToken(c *gin.Context) { | ||||||
| 	id, _ := strconv.Atoi(c.Param("id")) | 	id, _ := strconv.Atoi(c.Param("id")) | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	err := model.DeleteTokenById(id, userId) | 	err := model.DeleteTokenById(id, userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -160,7 +186,7 @@ func DeleteToken(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateToken(c *gin.Context) { | func UpdateToken(c *gin.Context) { | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt(ctxkey.Id) | ||||||
| 	statusOnly := c.Query("status_only") | 	statusOnly := c.Query("status_only") | ||||||
| 	token := model.Token{} | 	token := model.Token{} | ||||||
| 	err := c.ShouldBindJSON(&token) | 	err := c.ShouldBindJSON(&token) | ||||||
| @@ -171,10 +197,11 @@ func UpdateToken(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if len(token.Name) > 30 { | 	err = validateToken(c, token) | ||||||
|  | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "令牌名称过长", | 			"message": fmt.Sprintf("参数错误:%s", err.Error()), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -186,15 +213,15 @@ func UpdateToken(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if token.Status == common.TokenStatusEnabled { | 	if token.Status == model.TokenStatusEnabled { | ||||||
| 		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { | 		if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", | 				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", | ||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { | 		if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", | 				"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", | ||||||
| @@ -210,6 +237,8 @@ func UpdateToken(c *gin.Context) { | |||||||
| 		cleanToken.ExpiredTime = token.ExpiredTime | 		cleanToken.ExpiredTime = token.ExpiredTime | ||||||
| 		cleanToken.RemainQuota = token.RemainQuota | 		cleanToken.RemainQuota = token.RemainQuota | ||||||
| 		cleanToken.UnlimitedQuota = token.UnlimitedQuota | 		cleanToken.UnlimitedQuota = token.UnlimitedQuota | ||||||
|  | 		cleanToken.Models = token.Models | ||||||
|  | 		cleanToken.Subnet = token.Subnet | ||||||
| 	} | 	} | ||||||
| 	err = cleanToken.Update() | 	err = cleanToken.Update() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -3,9 +3,12 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -19,7 +22,7 @@ type LoginRequest struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Login(c *gin.Context) { | func Login(c *gin.Context) { | ||||||
| 	if !common.PasswordLoginEnabled { | 	if !config.PasswordLoginEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员关闭了密码登录", | 			"message": "管理员关闭了密码登录", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -56,11 +59,11 @@ func Login(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	setupLogin(&user, c) | 	SetupLogin(&user, c) | ||||||
| } | } | ||||||
|  |  | ||||||
| // setup session & cookies and then return user info | // setup session & cookies and then return user info | ||||||
| func setupLogin(user *model.User, c *gin.Context) { | func SetupLogin(user *model.User, c *gin.Context) { | ||||||
| 	session := sessions.Default(c) | 	session := sessions.Default(c) | ||||||
| 	session.Set("id", user.Id) | 	session.Set("id", user.Id) | ||||||
| 	session.Set("username", user.Username) | 	session.Set("username", user.Username) | ||||||
| @@ -106,14 +109,14 @@ func Logout(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func Register(c *gin.Context) { | func Register(c *gin.Context) { | ||||||
| 	if !common.RegisterEnabled { | 	if !config.RegisterEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员关闭了新用户注册", | 			"message": "管理员关闭了新用户注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if !common.PasswordRegisterEnabled { | 	if !config.PasswordRegisterEnabled { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", | 			"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -136,7 +139,7 @@ func Register(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if common.EmailVerificationEnabled { | 	if config.EmailVerificationEnabled { | ||||||
| 		if user.Email == "" || user.VerificationCode == "" { | 		if user.Email == "" || user.VerificationCode == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| @@ -160,7 +163,7 @@ func Register(c *gin.Context) { | |||||||
| 		DisplayName: user.Username, | 		DisplayName: user.Username, | ||||||
| 		InviterId:   inviterId, | 		InviterId:   inviterId, | ||||||
| 	} | 	} | ||||||
| 	if common.EmailVerificationEnabled { | 	if config.EmailVerificationEnabled { | ||||||
| 		cleanUser.Email = user.Email | 		cleanUser.Email = user.Email | ||||||
| 	} | 	} | ||||||
| 	if err := cleanUser.Insert(inviterId); err != nil { | 	if err := cleanUser.Insert(inviterId); err != nil { | ||||||
| @@ -170,6 +173,7 @@ func Register(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| @@ -182,7 +186,10 @@ func GetAllUsers(c *gin.Context) { | |||||||
| 	if p < 0 { | 	if p < 0 { | ||||||
| 		p = 0 | 		p = 0 | ||||||
| 	} | 	} | ||||||
| 	users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) |  | ||||||
|  | 	order := c.DefaultQuery("order", "") | ||||||
|  | 	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| @@ -190,12 +197,12 @@ func GetAllUsers(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    users, | 		"data":    users, | ||||||
| 	}) | 	}) | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUsers(c *gin.Context) { | func SearchUsers(c *gin.Context) { | ||||||
| @@ -233,8 +240,8 @@ func GetUser(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	myRole := c.GetInt("role") | 	myRole := c.GetInt(ctxkey.Role) | ||||||
| 	if myRole <= user.Role && myRole != common.RoleRootUser { | 	if myRole <= user.Role && myRole != model.RoleRootUser { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "无权获取同级或更高等级用户的信息", | 			"message": "无权获取同级或更高等级用户的信息", | ||||||
| @@ -250,7 +257,7 @@ func GetUser(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserDashboard(c *gin.Context) { | func GetUserDashboard(c *gin.Context) { | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt(ctxkey.Id) | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() | 	startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() | ||||||
| 	endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() | 	endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() | ||||||
| @@ -273,7 +280,7 @@ func GetUserDashboard(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GenerateAccessToken(c *gin.Context) { | func GenerateAccessToken(c *gin.Context) { | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt(ctxkey.Id) | ||||||
| 	user, err := model.GetUserById(id, true) | 	user, err := model.GetUserById(id, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -282,7 +289,7 @@ func GenerateAccessToken(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	user.AccessToken = common.GetUUID() | 	user.AccessToken = random.GetUUID() | ||||||
|  |  | ||||||
| 	if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { | 	if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -309,7 +316,7 @@ func GenerateAccessToken(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetAffCode(c *gin.Context) { | func GetAffCode(c *gin.Context) { | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt(ctxkey.Id) | ||||||
| 	user, err := model.GetUserById(id, true) | 	user, err := model.GetUserById(id, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -319,7 +326,7 @@ func GetAffCode(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if user.AffCode == "" { | 	if user.AffCode == "" { | ||||||
| 		user.AffCode = common.GetRandomString(4) | 		user.AffCode = random.GetRandomString(4) | ||||||
| 		if err := user.Update(false); err != nil { | 		if err := user.Update(false); err != nil { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| @@ -337,7 +344,7 @@ func GetAffCode(c *gin.Context) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetSelf(c *gin.Context) { | func GetSelf(c *gin.Context) { | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt(ctxkey.Id) | ||||||
| 	user, err := model.GetUserById(id, false) | 	user, err := model.GetUserById(id, false) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| @@ -382,15 +389,15 @@ func UpdateUser(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	myRole := c.GetInt("role") | 	myRole := c.GetInt(ctxkey.Role) | ||||||
| 	if myRole <= originUser.Role && myRole != common.RoleRootUser { | 	if myRole <= originUser.Role && myRole != model.RoleRootUser { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "无权更新同权限等级或更高权限等级的用户信息", | 			"message": "无权更新同权限等级或更高权限等级的用户信息", | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if myRole <= updatedUser.Role && myRole != common.RoleRootUser { | 	if myRole <= updatedUser.Role && myRole != model.RoleRootUser { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "无权将其他用户权限等级提升到大于等于自己的权限等级", | 			"message": "无权将其他用户权限等级提升到大于等于自己的权限等级", | ||||||
| @@ -440,7 +447,7 @@ func UpdateSelf(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cleanUser := model.User{ | 	cleanUser := model.User{ | ||||||
| 		Id:          c.GetInt("id"), | 		Id:          c.GetInt(ctxkey.Id), | ||||||
| 		Username:    user.Username, | 		Username:    user.Username, | ||||||
| 		Password:    user.Password, | 		Password:    user.Password, | ||||||
| 		DisplayName: user.DisplayName, | 		DisplayName: user.DisplayName, | ||||||
| @@ -504,7 +511,7 @@ func DeleteSelf(c *gin.Context) { | |||||||
| 	id := c.GetInt("id") | 	id := c.GetInt("id") | ||||||
| 	user, _ := model.GetUserById(id, false) | 	user, _ := model.GetUserById(id, false) | ||||||
|  |  | ||||||
| 	if user.Role == common.RoleRootUser { | 	if user.Role == model.RoleRootUser { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "不能删除超级管理员账户", | 			"message": "不能删除超级管理员账户", | ||||||
| @@ -606,7 +613,7 @@ func ManageUser(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	myRole := c.GetInt("role") | 	myRole := c.GetInt("role") | ||||||
| 	if myRole <= user.Role && myRole != common.RoleRootUser { | 	if myRole <= user.Role && myRole != model.RoleRootUser { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "无权更新同权限等级或更高权限等级的用户信息", | 			"message": "无权更新同权限等级或更高权限等级的用户信息", | ||||||
| @@ -615,8 +622,8 @@ func ManageUser(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	switch req.Action { | 	switch req.Action { | ||||||
| 	case "disable": | 	case "disable": | ||||||
| 		user.Status = common.UserStatusDisabled | 		user.Status = model.UserStatusDisabled | ||||||
| 		if user.Role == common.RoleRootUser { | 		if user.Role == model.RoleRootUser { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法禁用超级管理员用户", | 				"message": "无法禁用超级管理员用户", | ||||||
| @@ -624,9 +631,9 @@ func ManageUser(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "enable": | 	case "enable": | ||||||
| 		user.Status = common.UserStatusEnabled | 		user.Status = model.UserStatusEnabled | ||||||
| 	case "delete": | 	case "delete": | ||||||
| 		if user.Role == common.RoleRootUser { | 		if user.Role == model.RoleRootUser { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法删除超级管理员用户", | 				"message": "无法删除超级管理员用户", | ||||||
| @@ -641,37 +648,37 @@ func ManageUser(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	case "promote": | 	case "promote": | ||||||
| 		if myRole != common.RoleRootUser { | 		if myRole != model.RoleRootUser { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "普通管理员用户无法提升其他用户为管理员", | 				"message": "普通管理员用户无法提升其他用户为管理员", | ||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if user.Role >= common.RoleAdminUser { | 		if user.Role >= model.RoleAdminUser { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "该用户已经是管理员", | 				"message": "该用户已经是管理员", | ||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		user.Role = common.RoleAdminUser | 		user.Role = model.RoleAdminUser | ||||||
| 	case "demote": | 	case "demote": | ||||||
| 		if user.Role == common.RoleRootUser { | 		if user.Role == model.RoleRootUser { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法降级超级管理员用户", | 				"message": "无法降级超级管理员用户", | ||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if user.Role == common.RoleCommonUser { | 		if user.Role == model.RoleCommonUser { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "该用户已经是普通用户", | 				"message": "该用户已经是普通用户", | ||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		user.Role = common.RoleCommonUser | 		user.Role = model.RoleCommonUser | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := user.Update(false); err != nil { | 	if err := user.Update(false); err != nil { | ||||||
| @@ -725,8 +732,8 @@ func EmailBind(c *gin.Context) { | |||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if user.Role == common.RoleRootUser { | 	if user.Role == model.RoleRootUser { | ||||||
| 		common.RootUserEmail = email | 		config.RootUserEmail = email | ||||||
| 	} | 	} | ||||||
| 	c.JSON(http.StatusOK, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| @@ -765,3 +772,38 @@ func TopUp(c *gin.Context) { | |||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type adminTopUpRequest struct { | ||||||
|  | 	UserId int    `json:"user_id"` | ||||||
|  | 	Quota  int    `json:"quota"` | ||||||
|  | 	Remark string `json:"remark"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func AdminTopUp(c *gin.Context) { | ||||||
|  | 	req := adminTopUpRequest{} | ||||||
|  | 	err := c.ShouldBindJSON(&req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if req.Remark == "" { | ||||||
|  | 		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) | ||||||
|  | 	} | ||||||
|  | 	model.RecordTopupLog(req.UserId, req.Remark, req.Quota) | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ version: '3.4' | |||||||
|  |  | ||||||
| services: | services: | ||||||
|   one-api: |   one-api: | ||||||
|     image: justsong/one-api:latest |     image: "${REGISTRY:-docker.io}/justsong/one-api:latest" | ||||||
|     container_name: one-api |     container_name: one-api | ||||||
|     restart: always |     restart: always | ||||||
|     command: --log-dir /app/logs |     command: --log-dir /app/logs | ||||||
| @@ -29,12 +29,12 @@ services: | |||||||
|       retries: 3 |       retries: 3 | ||||||
|  |  | ||||||
|   redis: |   redis: | ||||||
|     image: redis:latest |     image: "${REGISTRY:-docker.io}/redis:latest" | ||||||
|     container_name: redis |     container_name: redis | ||||||
|     restart: always |     restart: always | ||||||
|  |  | ||||||
|   db: |   db: | ||||||
|     image: mysql:8.2.0 |     image: "${REGISTRY:-docker.io}/mysql:8.2.0" | ||||||
|     restart: always |     restart: always | ||||||
|     container_name: mysql |     container_name: mysql | ||||||
|     volumes: |     volumes: | ||||||
|   | |||||||
							
								
								
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | # 使用 API 操控 & 扩展 One API | ||||||
|  | > 欢迎提交 PR 在此放上你的拓展项目。 | ||||||
|  |  | ||||||
|  | 例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 | ||||||
|  |  | ||||||
|  | 又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 | ||||||
|  |  | ||||||
|  | ## 鉴权 | ||||||
|  | One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | 之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## 请求格式与响应格式 | ||||||
|  | One API 使用 JSON 格式进行请求和响应。 | ||||||
|  |  | ||||||
|  | 对于响应体,一般格式如下: | ||||||
|  | ```json | ||||||
|  | { | ||||||
|  |   "message": "请求信息", | ||||||
|  |   "success": true, | ||||||
|  |   "data": {} | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## API 列表 | ||||||
|  | > 当前 API 列表不全,请自行通过浏览器抓取前端请求 | ||||||
|  |  | ||||||
|  | 如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 | ||||||
|  |  | ||||||
|  | ### 获取当前登录用户信息 | ||||||
|  | **GET** `/api/user/self` | ||||||
|  |  | ||||||
|  | ### 为给定用户充值额度 | ||||||
|  | **POST** `/api/topup` | ||||||
|  | ```json | ||||||
|  | { | ||||||
|  |   "user_id": 1, | ||||||
|  |   "quota": 100000, | ||||||
|  |   "remark": "充值 100000 额度" | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## 其他 | ||||||
|  | ### 充值链接上的附加参数 | ||||||
|  | One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: | ||||||
|  | `https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` | ||||||
|  |  | ||||||
|  | 你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 | ||||||
|  |  | ||||||
|  | 注意,不是所有主题都支持该功能,欢迎 PR 补齐。 | ||||||
							
								
								
									
										101
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										101
									
								
								go.mod
									
									
									
									
									
								
							| @@ -1,65 +1,86 @@ | |||||||
| module one-api | module github.com/songquanpeng/one-api | ||||||
|  |  | ||||||
| // +heroku goVersion go1.18 | // +heroku goVersion go1.18 | ||||||
| go 1.18 | go 1.20 | ||||||
|  |  | ||||||
| require ( | require ( | ||||||
| 	github.com/gin-contrib/cors v1.4.0 | 	github.com/aws/aws-sdk-go-v2 v1.27.0 | ||||||
| 	github.com/gin-contrib/gzip v0.0.6 | 	github.com/aws/aws-sdk-go-v2/credentials v1.17.15 | ||||||
| 	github.com/gin-contrib/sessions v0.0.5 | 	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 | ||||||
| 	github.com/gin-contrib/static v0.0.1 | 	github.com/gin-contrib/cors v1.7.2 | ||||||
| 	github.com/gin-gonic/gin v1.9.1 | 	github.com/gin-contrib/gzip v1.0.1 | ||||||
| 	github.com/go-playground/validator/v10 v10.14.0 | 	github.com/gin-contrib/sessions v1.0.1 | ||||||
|  | 	github.com/gin-contrib/static v1.1.2 | ||||||
|  | 	github.com/gin-gonic/gin v1.10.0 | ||||||
|  | 	github.com/go-playground/validator/v10 v10.20.0 | ||||||
| 	github.com/go-redis/redis/v8 v8.11.5 | 	github.com/go-redis/redis/v8 v8.11.5 | ||||||
| 	github.com/golang-jwt/jwt v3.2.2+incompatible | 	github.com/golang-jwt/jwt v3.2.2+incompatible | ||||||
| 	github.com/google/uuid v1.3.0 | 	github.com/google/uuid v1.6.0 | ||||||
| 	github.com/gorilla/websocket v1.5.0 | 	github.com/gorilla/websocket v1.5.1 | ||||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | 	github.com/jinzhu/copier v0.4.0 | ||||||
| 	github.com/stretchr/testify v1.8.3 | 	github.com/joho/godotenv v1.5.1 | ||||||
| 	golang.org/x/crypto v0.17.0 | 	github.com/pkg/errors v0.9.1 | ||||||
| 	golang.org/x/image v0.14.0 | 	github.com/pkoukk/tiktoken-go v0.1.7 | ||||||
| 	gorm.io/driver/mysql v1.4.3 | 	github.com/smartystreets/goconvey v1.8.1 | ||||||
| 	gorm.io/driver/postgres v1.5.2 | 	github.com/stretchr/testify v1.9.0 | ||||||
| 	gorm.io/driver/sqlite v1.4.3 | 	golang.org/x/crypto v0.23.0 | ||||||
| 	gorm.io/gorm v1.25.0 | 	golang.org/x/image v0.18.0 | ||||||
|  | 	gorm.io/driver/mysql v1.5.6 | ||||||
|  | 	gorm.io/driver/postgres v1.5.7 | ||||||
|  | 	gorm.io/driver/sqlite v1.5.5 | ||||||
|  | 	gorm.io/gorm v1.25.10 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| require ( | require ( | ||||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | 	filippo.io/edwards25519 v1.1.0 // indirect | ||||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect | ||||||
| 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect | ||||||
|  | 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect | ||||||
|  | 	github.com/aws/smithy-go v1.20.2 // indirect | ||||||
|  | 	github.com/bytedance/sonic v1.11.6 // indirect | ||||||
|  | 	github.com/bytedance/sonic/loader v0.1.1 // indirect | ||||||
|  | 	github.com/cespare/xxhash/v2 v2.3.0 // indirect | ||||||
|  | 	github.com/cloudwego/base64x v0.1.4 // indirect | ||||||
|  | 	github.com/cloudwego/iasm v0.2.0 // indirect | ||||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||||
| 	github.com/dlclark/regexp2 v1.10.0 // indirect | 	github.com/dlclark/regexp2 v1.11.0 // indirect | ||||||
| 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | 	github.com/fsnotify/fsnotify v1.7.0 // indirect | ||||||
|  | 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect | ||||||
| 	github.com/gin-contrib/sse v0.1.0 // indirect | 	github.com/gin-contrib/sse v0.1.0 // indirect | ||||||
| 	github.com/go-playground/locales v0.14.1 // indirect | 	github.com/go-playground/locales v0.14.1 // indirect | ||||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||||
| 	github.com/go-sql-driver/mysql v1.6.0 // indirect | 	github.com/go-sql-driver/mysql v1.8.1 // indirect | ||||||
| 	github.com/goccy/go-json v0.10.2 // indirect | 	github.com/goccy/go-json v0.10.3 // indirect | ||||||
| 	github.com/gorilla/context v1.1.1 // indirect | 	github.com/gopherjs/gopherjs v1.17.2 // indirect | ||||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | 	github.com/gorilla/context v1.1.2 // indirect | ||||||
| 	github.com/gorilla/sessions v1.2.1 // indirect | 	github.com/gorilla/securecookie v1.1.2 // indirect | ||||||
|  | 	github.com/gorilla/sessions v1.2.2 // indirect | ||||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||||
| 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | 	github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect | ||||||
| 	github.com/jackc/pgx/v5 v5.3.1 // indirect | 	github.com/jackc/pgx/v5 v5.5.5 // indirect | ||||||
|  | 	github.com/jackc/puddle/v2 v2.2.1 // indirect | ||||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||||
| 	github.com/jinzhu/now v1.1.5 // indirect | 	github.com/jinzhu/now v1.1.5 // indirect | ||||||
| 	github.com/json-iterator/go v1.1.12 // indirect | 	github.com/json-iterator/go v1.1.12 // indirect | ||||||
| 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | 	github.com/jtolds/gls v4.20.0+incompatible // indirect | ||||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | 	github.com/klauspost/cpuid/v2 v2.2.7 // indirect | ||||||
| 	github.com/mattn/go-isatty v0.0.19 // indirect | 	github.com/kr/text v0.2.0 // indirect | ||||||
|  | 	github.com/leodido/go-urn v1.4.0 // indirect | ||||||
|  | 	github.com/mattn/go-isatty v0.0.20 // indirect | ||||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | ||||||
| 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect | 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect | ||||||
| 	github.com/modern-go/reflect2 v1.0.2 // indirect | 	github.com/modern-go/reflect2 v1.0.2 // indirect | ||||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | 	github.com/pelletier/go-toml/v2 v2.2.2 // indirect | ||||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||||
|  | 	github.com/smarty/assertions v1.15.0 // indirect | ||||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | 	github.com/ugorji/go/codec v1.2.12 // indirect | ||||||
| 	golang.org/x/arch v0.3.0 // indirect | 	golang.org/x/arch v0.8.0 // indirect | ||||||
| 	golang.org/x/net v0.17.0 // indirect | 	golang.org/x/net v0.25.0 // indirect | ||||||
| 	golang.org/x/sys v0.15.0 // indirect | 	golang.org/x/sync v0.7.0 // indirect | ||||||
| 	golang.org/x/text v0.14.0 // indirect | 	golang.org/x/sys v0.20.0 // indirect | ||||||
| 	google.golang.org/protobuf v1.30.0 // indirect | 	golang.org/x/text v0.16.0 // indirect | ||||||
|  | 	google.golang.org/protobuf v1.34.1 // indirect | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										269
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										269
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,206 +1,189 @@ | |||||||
| github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= | ||||||
| github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= | ||||||
| github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= | ||||||
| github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= | github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= | ||||||
| github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= | ||||||
| github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= | ||||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= | github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo= | ||||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= | github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU= | ||||||
|  | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk= | ||||||
|  | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI= | ||||||
|  | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g= | ||||||
|  | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI= | ||||||
|  | github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY= | ||||||
|  | github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ= | ||||||
|  | github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= | ||||||
|  | github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= | ||||||
|  | github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= | ||||||
|  | github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= | ||||||
|  | github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= | ||||||
|  | github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= | ||||||
|  | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= | ||||||
|  | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||||
|  | github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= | ||||||
|  | github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= | ||||||
|  | github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= | ||||||
|  | github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= | ||||||
| github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= | ||||||
| github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||||
| github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||||||
| github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= | ||||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | ||||||
| github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= | github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= | ||||||
| github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | ||||||
| github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= | ||||||
| github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= | github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= | ||||||
| github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= | github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= | ||||||
| github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= | github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= | ||||||
| github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs= | github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= | ||||||
| github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4= | github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= | ||||||
| github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk= | github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE= | ||||||
| github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE= | github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4= | ||||||
| github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY= | github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI= | ||||||
|  | github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM= | ||||||
| github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= | github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= | ||||||
| github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= | github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= | ||||||
| github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Swm1U= | github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4= | ||||||
| github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs= | github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw= | ||||||
| github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= | github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= | ||||||
| github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= | github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= | ||||||
| github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= |  | ||||||
| github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= |  | ||||||
| github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= |  | ||||||
| github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= | ||||||
| github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= |  | ||||||
| github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= |  | ||||||
| github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= | ||||||
| github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= | ||||||
| github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= |  | ||||||
| github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= |  | ||||||
| github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= | ||||||
| github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= | ||||||
| github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= | github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= | ||||||
| github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= | github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= | ||||||
| github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= |  | ||||||
| github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= |  | ||||||
| github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= | github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= | ||||||
| github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= | github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= | ||||||
| github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= | github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= | ||||||
| github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= | github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= | ||||||
| github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= | ||||||
| github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= | github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= | ||||||
| github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= | ||||||
| github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||||
| github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||||
| github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= | github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= | ||||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= |  | ||||||
| github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= |  | ||||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= |  | ||||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||||
| github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= | github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= | ||||||
| github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= | ||||||
| github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||||||
| github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= | ||||||
| github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= | ||||||
| github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= | ||||||
| github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM= | ||||||
| github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= | ||||||
| github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= | ||||||
| github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= | ||||||
|  | github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= | ||||||
|  | github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= | ||||||
|  | github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= | ||||||
| github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= | ||||||
| github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= | ||||||
| github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= | ||||||
| github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= | ||||||
| github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= | github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= | ||||||
| github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= | github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= | ||||||
|  | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= | ||||||
|  | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= | ||||||
|  | github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= | ||||||
|  | github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= | ||||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= |  | ||||||
| github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= | ||||||
| github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||||
| github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= | github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= | ||||||
|  | github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= | ||||||
| github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | ||||||
| github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | ||||||
|  | github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= | ||||||
|  | github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= | ||||||
| github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= | github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= | ||||||
| github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= | github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= | ||||||
| github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= | github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= | ||||||
| github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= | github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= | ||||||
| github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= |  | ||||||
| github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= | ||||||
| github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= |  | ||||||
| github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= |  | ||||||
| github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= |  | ||||||
| github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= | ||||||
| github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= | ||||||
| github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= | github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= | ||||||
| github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= | github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= | ||||||
| github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= | ||||||
| github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= | ||||||
| github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= |  | ||||||
| github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= |  | ||||||
| github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= |  | ||||||
| github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= |  | ||||||
| github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= |  | ||||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= | github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= | ||||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= | github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= | ||||||
| github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | ||||||
| github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= | ||||||
| github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | ||||||
| github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= |  | ||||||
| github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= | ||||||
| github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= | ||||||
| github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | ||||||
| github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= | github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= | ||||||
| github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= | github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= | ||||||
| github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= | github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= | ||||||
| github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= | github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= | ||||||
| github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= | ||||||
| github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||||
| github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= | github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= | ||||||
| github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | ||||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||||
| github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= |  | ||||||
| github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= | github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= | ||||||
| github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= | github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= | ||||||
|  | github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= | ||||||
|  | github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= | ||||||
|  | github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= | ||||||
| github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||||||
| github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= | ||||||
| github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= | ||||||
|  | github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= | ||||||
| github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= | ||||||
| github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= |  | ||||||
| github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= |  | ||||||
| github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | ||||||
| github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | ||||||
| github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= | ||||||
| github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= | ||||||
| github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | ||||||
| github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= | ||||||
| github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= | ||||||
| github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= | ||||||
| github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | ||||||
| github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= | github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= | ||||||
| github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= | github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= | ||||||
| github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= |  | ||||||
| github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= |  | ||||||
| github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= |  | ||||||
| github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= |  | ||||||
| golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= | ||||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | ||||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= | ||||||
| golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= | golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= | ||||||
| golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= | golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= | ||||||
| golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= | golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= | ||||||
| golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= | golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= | ||||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= | ||||||
| golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= | ||||||
| golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= |  | ||||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= |  | ||||||
| golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= |  | ||||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= |  | ||||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= |  | ||||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= | golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= | ||||||
| golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= | ||||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= | ||||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= | ||||||
| golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= | google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= | ||||||
| golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= |  | ||||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= |  | ||||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= |  | ||||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= |  | ||||||
| google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= |  | ||||||
| google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= |  | ||||||
| google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= |  | ||||||
| google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= |  | ||||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= |  | ||||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= |  | ||||||
| gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= |  | ||||||
| gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | ||||||
| gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= |  | ||||||
| gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= |  | ||||||
| gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= | ||||||
| gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= |  | ||||||
| gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||||
| gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= |  | ||||||
| gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||||
| gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||||
| gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= | gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= | ||||||
| gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= | gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= | ||||||
| gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= | gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= | ||||||
| gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= | gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= | ||||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= | ||||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= | ||||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= | ||||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= | ||||||
| gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= | ||||||
| gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= | ||||||
| rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= | ||||||
|   | |||||||
							
								
								
									
										35
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								i18n/en.json
									
									
									
									
									
								
							| @@ -8,12 +8,12 @@ | |||||||
|   "确认删除": "Confirm Delete", |   "确认删除": "Confirm Delete", | ||||||
|   "确认绑定": "Confirm Binding", |   "确认绑定": "Confirm Binding", | ||||||
|   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", |   "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", | ||||||
|   "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", |   "\"渠道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", | ||||||
|   "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", |   "渠道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", | ||||||
|   "测试已在运行中": "Test is already running", |   "测试已在运行中": "Test is already running", | ||||||
|   "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", |   "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", | ||||||
|   "通道测试完成": "Channel test completed", |   "渠道测试完成": "Channel test completed", | ||||||
|   "通道测试完成,如果没有收到禁用通知,说明所有通道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", |   "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", | ||||||
|   "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", |   "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", | ||||||
|   "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", |   "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", | ||||||
|   "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", |   "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", | ||||||
| @@ -119,11 +119,11 @@ | |||||||
|   " 个月 ": " M ", |   " 个月 ": " M ", | ||||||
|   " 年 ": " y ", |   " 年 ": " y ", | ||||||
|   "未测试": "Not tested", |   "未测试": "Not tested", | ||||||
|   "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", |   "渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", | ||||||
|   "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", |   "已成功开始测试所有渠道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", | ||||||
|   "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", |   "已成功开始测试所有已启用渠道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", | ||||||
|   "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", |   "渠道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", | ||||||
|   "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", |   "已更新完毕所有已启用渠道余额!": "The balance of all enabled channels has been updated!", | ||||||
|   "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", |   "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", | ||||||
|   "名称": "Name", |   "名称": "Name", | ||||||
|   "分组": "Group", |   "分组": "Group", | ||||||
| @@ -141,9 +141,9 @@ | |||||||
|   "启用": "Enable", |   "启用": "Enable", | ||||||
|   "编辑": "Edit", |   "编辑": "Edit", | ||||||
|   "添加新的渠道": "Add a new channel", |   "添加新的渠道": "Add a new channel", | ||||||
|   "测试所有通道": "Test all channels", |   "测试所有渠道": "Test all channels", | ||||||
|   "测试所有已启用通道": "Test all enabled channels", |   "测试所有已启用渠道": "Test all enabled channels", | ||||||
|   "更新所有已启用通道余额": "Update the balance of all enabled channels", |   "更新所有已启用渠道余额": "Update the balance of all enabled channels", | ||||||
|   "刷新": "Refresh", |   "刷新": "Refresh", | ||||||
|   "处理中...": "Processing...", |   "处理中...": "Processing...", | ||||||
|   "绑定成功!": "Binding succeeded!", |   "绑定成功!": "Binding succeeded!", | ||||||
| @@ -207,11 +207,11 @@ | |||||||
|   "监控设置": "Monitoring Settings", |   "监控设置": "Monitoring Settings", | ||||||
|   "最长响应时间": "Longest Response Time", |   "最长响应时间": "Longest Response Time", | ||||||
|   "单位秒": "Unit in seconds", |   "单位秒": "Unit in seconds", | ||||||
|   "当运行通道全部测试时": "When all operating channels are tested", |   "当运行渠道全部测试时": "When all operating channels are tested", | ||||||
|   "超过此时间将自动禁用通道": "Channels will be automatically disabled if this time is exceeded", |   "超过此时间将自动禁用渠道": "Channels will be automatically disabled if this time is exceeded", | ||||||
|   "额度提醒阈值": "Quota reminder threshold", |   "额度提醒阈值": "Quota reminder threshold", | ||||||
|   "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", |   "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", | ||||||
|   "失败时自动禁用通道": "Automatically disable the channel when it fails", |   "失败时自动禁用渠道": "Automatically disable the channel when it fails", | ||||||
|   "保存监控设置": "Save Monitoring Settings", |   "保存监控设置": "Save Monitoring Settings", | ||||||
|   "额度设置": "Quota Settings", |   "额度设置": "Quota Settings", | ||||||
|   "新用户初始额度": "Initial quota for new users", |   "新用户初始额度": "Initial quota for new users", | ||||||
| @@ -405,7 +405,7 @@ | |||||||
|   "镜像": "Mirror", |   "镜像": "Mirror", | ||||||
|   "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", |   "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", | ||||||
|   "模型": "Model", |   "模型": "Model", | ||||||
|   "请选择该通道所支持的模型": "Please select the model supported by the channel", |   "请选择该渠道所支持的模型": "Please select the model supported by the channel", | ||||||
|   "填入基础模型": "Fill in the basic model", |   "填入基础模型": "Fill in the basic model", | ||||||
|   "填入所有模型": "Fill in all models", |   "填入所有模型": "Fill in all models", | ||||||
|   "清除所有模型": "Clear all models", |   "清除所有模型": "Clear all models", | ||||||
| @@ -456,6 +456,7 @@ | |||||||
|   "已绑定的邮箱账户": "Email Account Bound", |   "已绑定的邮箱账户": "Email Account Bound", | ||||||
|   "用户信息更新成功!": "User information updated successfully!", |   "用户信息更新成功!": "User information updated successfully!", | ||||||
|   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", |   "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", | ||||||
|  |   "模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f", | ||||||
|   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", |   "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", | ||||||
|   "用户名称": "User Name", |   "用户名称": "User Name", | ||||||
|   "令牌名称": "Token Name", |   "令牌名称": "Token Name", | ||||||
| @@ -514,7 +515,7 @@ | |||||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", |   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||||
|   "Homepage URL 填": "Fill in the Homepage URL", |   "Homepage URL 填": "Fill in the Homepage URL", | ||||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL", |   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||||
|   "请为通道命名": "Please name the channel", |   "请为渠道命名": "Please name the channel", | ||||||
|   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", |   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||||
|   "模型重定向": "Model redirection", |   "模型重定向": "Model redirection", | ||||||
|   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", |   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||||
|   | |||||||
							
								
								
									
										80
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										80
									
								
								main.go
									
									
									
									
									
								
							| @@ -6,11 +6,16 @@ import ( | |||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-contrib/sessions/cookie" | 	"github.com/gin-contrib/sessions/cookie" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"one-api/common" | 	_ "github.com/joho/godotenv/autoload" | ||||||
| 	"one-api/controller" | 	"github.com/songquanpeng/one-api/common" | ||||||
| 	"one-api/middleware" | 	"github.com/songquanpeng/one-api/common/client" | ||||||
| 	"one-api/model" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"one-api/router" | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/controller" | ||||||
|  | 	"github.com/songquanpeng/one-api/middleware" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||||
|  | 	"github.com/songquanpeng/one-api/router" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| ) | ) | ||||||
| @@ -19,68 +24,72 @@ import ( | |||||||
| var buildFS embed.FS | var buildFS embed.FS | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| 	common.SetupLogger() | 	common.Init() | ||||||
| 	common.SysLog(fmt.Sprintf("One API %s started", common.Version)) | 	logger.SetupLogger() | ||||||
| 	if os.Getenv("GIN_MODE") != "debug" { | 	logger.SysLogf("One API %s started", common.Version) | ||||||
|  |  | ||||||
|  | 	if os.Getenv("GIN_MODE") != gin.DebugMode { | ||||||
| 		gin.SetMode(gin.ReleaseMode) | 		gin.SetMode(gin.ReleaseMode) | ||||||
| 	} | 	} | ||||||
| 	if common.DebugEnabled { | 	if config.DebugEnabled { | ||||||
| 		common.SysLog("running in debug mode") | 		logger.SysLog("running in debug mode") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Initialize SQL Database | 	// Initialize SQL Database | ||||||
| 	err := model.InitDB() | 	model.InitDB() | ||||||
|  | 	model.InitLogDB() | ||||||
|  |  | ||||||
|  | 	var err error | ||||||
|  | 	err = model.CreateRootAccountIfNeed() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.FatalLog("failed to initialize database: " + err.Error()) | 		logger.FatalLog("database init error: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		err := model.CloseDB() | 		err := model.CloseDB() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.FatalLog("failed to close database: " + err.Error()) | 			logger.FatalLog("failed to close database: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// Initialize Redis | 	// Initialize Redis | ||||||
| 	err = common.InitRedisClient() | 	err = common.InitRedisClient() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.FatalLog("failed to initialize Redis: " + err.Error()) | 		logger.FatalLog("failed to initialize Redis: " + err.Error()) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Initialize options | 	// Initialize options | ||||||
| 	model.InitOptionMap() | 	model.InitOptionMap() | ||||||
| 	common.SysLog(fmt.Sprintf("using theme %s", common.Theme)) | 	logger.SysLog(fmt.Sprintf("using theme %s", config.Theme)) | ||||||
| 	if common.RedisEnabled { | 	if common.RedisEnabled { | ||||||
| 		// for compatibility with old versions | 		// for compatibility with old versions | ||||||
| 		common.MemoryCacheEnabled = true | 		config.MemoryCacheEnabled = true | ||||||
| 	} | 	} | ||||||
| 	if common.MemoryCacheEnabled { | 	if config.MemoryCacheEnabled { | ||||||
| 		common.SysLog("memory cache enabled") | 		logger.SysLog("memory cache enabled") | ||||||
| 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) | 		logger.SysLog(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) | ||||||
| 		model.InitChannelCache() | 		model.InitChannelCache() | ||||||
| 	} | 	} | ||||||
| 	if common.MemoryCacheEnabled { | 	if config.MemoryCacheEnabled { | ||||||
| 		go model.SyncOptions(common.SyncFrequency) | 		go model.SyncOptions(config.SyncFrequency) | ||||||
| 		go model.SyncChannelCache(common.SyncFrequency) | 		go model.SyncChannelCache(config.SyncFrequency) | ||||||
| 	} |  | ||||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { |  | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		go controller.AutomaticallyUpdateChannels(frequency) |  | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) | 			logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		go controller.AutomaticallyTestChannels(frequency) | 		go controller.AutomaticallyTestChannels(frequency) | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | ||||||
| 		common.BatchUpdateEnabled = true | 		config.BatchUpdateEnabled = true | ||||||
| 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | ||||||
| 		model.InitBatchUpdater() | 		model.InitBatchUpdater() | ||||||
| 	} | 	} | ||||||
| 	controller.InitTokenEncoders() | 	if config.EnableMetric { | ||||||
|  | 		logger.SysLog("metric enabled, will disable channel if too much request failed") | ||||||
|  | 	} | ||||||
|  | 	openai.InitTokenEncoders() | ||||||
|  | 	client.Init() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
| 	server := gin.New() | 	server := gin.New() | ||||||
| @@ -90,7 +99,7 @@ func main() { | |||||||
| 	server.Use(middleware.RequestId()) | 	server.Use(middleware.RequestId()) | ||||||
| 	middleware.SetUpLogger(server) | 	middleware.SetUpLogger(server) | ||||||
| 	// Initialize session store | 	// Initialize session store | ||||||
| 	store := cookie.NewStore([]byte(common.SessionSecret)) | 	store := cookie.NewStore([]byte(config.SessionSecret)) | ||||||
| 	server.Use(sessions.Sessions("session", store)) | 	server.Use(sessions.Sessions("session", store)) | ||||||
|  |  | ||||||
| 	router.SetRouter(server, buildFS) | 	router.SetRouter(server, buildFS) | ||||||
| @@ -98,8 +107,9 @@ func main() { | |||||||
| 	if port == "" { | 	if port == "" { | ||||||
| 		port = strconv.Itoa(*common.Port) | 		port = strconv.Itoa(*common.Port) | ||||||
| 	} | 	} | ||||||
|  | 	logger.SysLogf("server started on http://localhost:%s", port) | ||||||
| 	err = server.Run(":" + port) | 	err = server.Run(":" + port) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.FatalLog("failed to start HTTP server: " + err.Error()) | 		logger.FatalLog("failed to start HTTP server: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,11 +1,14 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/blacklist" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/network" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -42,11 +45,14 @@ func authHelper(c *gin.Context, minRole int) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if status.(int) == common.UserStatusDisabled { | 	if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": "用户已被封禁", | 			"message": "用户已被封禁", | ||||||
| 		}) | 		}) | ||||||
|  | 		session := sessions.Default(c) | ||||||
|  | 		session.Clear() | ||||||
|  | 		_ = session.Save() | ||||||
| 		c.Abort() | 		c.Abort() | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -66,24 +72,25 @@ func authHelper(c *gin.Context, minRole int) { | |||||||
|  |  | ||||||
| func UserAuth() func(c *gin.Context) { | func UserAuth() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		authHelper(c, common.RoleCommonUser) | 		authHelper(c, model.RoleCommonUser) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func AdminAuth() func(c *gin.Context) { | func AdminAuth() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		authHelper(c, common.RoleAdminUser) | 		authHelper(c, model.RoleAdminUser) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RootAuth() func(c *gin.Context) { | func RootAuth() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		authHelper(c, common.RoleRootUser) | 		authHelper(c, model.RoleRootUser) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TokenAuth() func(c *gin.Context) { | func TokenAuth() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
|  | 		ctx := c.Request.Context() | ||||||
| 		key := c.Request.Header.Get("Authorization") | 		key := c.Request.Header.Get("Authorization") | ||||||
| 		key = strings.TrimPrefix(key, "Bearer ") | 		key = strings.TrimPrefix(key, "Bearer ") | ||||||
| 		key = strings.TrimPrefix(key, "sk-") | 		key = strings.TrimPrefix(key, "sk-") | ||||||
| @@ -94,21 +101,40 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | 		if token.Subnet != nil && *token.Subnet != "" { | ||||||
|  | 			if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { | ||||||
|  | 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
| 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if !userEnabled { | 		if !userEnabled || blacklist.IsUserBanned(token.UserId) { | ||||||
| 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		c.Set("id", token.UserId) | 		requestModel, err := getRequestModel(c) | ||||||
| 		c.Set("token_id", token.Id) | 		if err != nil && shouldCheckModel(c) { | ||||||
| 		c.Set("token_name", token.Name) | 			abortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		c.Set(ctxkey.RequestModel, requestModel) | ||||||
|  | 		if token.Models != nil && *token.Models != "" { | ||||||
|  | 			c.Set(ctxkey.AvailableModels, *token.Models) | ||||||
|  | 			if requestModel != "" && !isModelInList(requestModel, *token.Models) { | ||||||
|  | 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		c.Set(ctxkey.Id, token.UserId) | ||||||
|  | 		c.Set(ctxkey.TokenId, token.Id) | ||||||
|  | 		c.Set(ctxkey.TokenName, token.Name) | ||||||
| 		if len(parts) > 1 { | 		if len(parts) > 1 { | ||||||
| 			if model.IsAdmin(token.UserId) { | 			if model.IsAdmin(token.UserId) { | ||||||
| 				c.Set("channelId", parts[1]) | 				c.Set(ctxkey.SpecificChannelId, parts[1]) | ||||||
| 			} else { | 			} else { | ||||||
| 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||||
| 				return | 				return | ||||||
| @@ -117,3 +143,19 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func shouldCheckModel(c *gin.Context) bool { | ||||||
|  | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,13 +2,13 @@ package middleware | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" |  | ||||||
| 	"one-api/common" |  | ||||||
| 	"one-api/model" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type ModelRequest struct { | type ModelRequest struct { | ||||||
| @@ -17,11 +17,12 @@ type ModelRequest struct { | |||||||
|  |  | ||||||
| func Distribute() func(c *gin.Context) { | func Distribute() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		userId := c.GetInt("id") | 		userId := c.GetInt(ctxkey.Id) | ||||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||||
| 		c.Set("group", userGroup) | 		c.Set(ctxkey.Group, userGroup) | ||||||
|  | 		var requestModel string | ||||||
| 		var channel *model.Channel | 		var channel *model.Channel | ||||||
| 		channelId, ok := c.Get("channelId") | 		channelId, ok := c.Get(ctxkey.SpecificChannelId) | ||||||
| 		if ok { | 		if ok { | ||||||
| 			id, err := strconv.Atoi(channelId.(string)) | 			id, err := strconv.Atoi(channelId.(string)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -33,67 +34,62 @@ func Distribute() func(c *gin.Context) { | |||||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if channel.Status != common.ChannelStatusEnabled { | 			if channel.Status != model.ChannelStatusEnabled { | ||||||
| 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			// Select a channel for the user | 			requestModel = c.GetString(ctxkey.RequestModel) | ||||||
| 			var modelRequest ModelRequest | 			var err error | ||||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = "text-moderation-stable" |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = c.Param("model") |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = "dall-e-2" |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { |  | ||||||
| 				if modelRequest.Model == "" { |  | ||||||
| 					modelRequest.Model = "whisper-1" |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) |  | ||||||
| 			if err != nil { |  | ||||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) |  | ||||||
| 				if channel != nil { | 				if channel != nil { | ||||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||||
| 					message = "数据库一致性已被破坏,请联系管理员" | 					message = "数据库一致性已被破坏,请联系管理员" | ||||||
| 				} | 				} | ||||||
| 				abortWithMessage(c, http.StatusServiceUnavailable, message) | 				abortWithMessage(c, http.StatusServiceUnavailable, message) | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		c.Set("channel", channel.Type) | 		SetupContextForSelectedChannel(c, channel, requestModel) | ||||||
| 		c.Set("channel_id", channel.Id) |  | ||||||
| 		c.Set("channel_name", channel.Name) |  | ||||||
| 		c.Set("model_mapping", channel.GetModelMapping()) |  | ||||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) |  | ||||||
| 		c.Set("base_url", channel.GetBaseURL()) |  | ||||||
| 		switch channel.Type { |  | ||||||
| 		case common.ChannelTypeAzure: |  | ||||||
| 			c.Set("api_version", channel.Other) |  | ||||||
| 		case common.ChannelTypeXunfei: |  | ||||||
| 			c.Set("api_version", channel.Other) |  | ||||||
| 		case common.ChannelTypeGemini: |  | ||||||
| 			c.Set("api_version", channel.Other) |  | ||||||
| 		case common.ChannelTypeAIProxyLibrary: |  | ||||||
| 			c.Set("library_id", channel.Other) |  | ||||||
| 		case common.ChannelTypeAli: |  | ||||||
| 			c.Set("plugin", channel.Other) |  | ||||||
| 		} |  | ||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { | ||||||
|  | 	c.Set(ctxkey.Channel, channel.Type) | ||||||
|  | 	c.Set(ctxkey.ChannelId, channel.Id) | ||||||
|  | 	c.Set(ctxkey.ChannelName, channel.Name) | ||||||
|  | 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) | ||||||
|  | 	c.Set(ctxkey.OriginalModel, modelName) // for retry | ||||||
|  | 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||||
|  | 	c.Set(ctxkey.BaseURL, channel.GetBaseURL()) | ||||||
|  | 	cfg, _ := channel.LoadConfig() | ||||||
|  | 	// this is for backward compatibility | ||||||
|  | 	if channel.Other != nil { | ||||||
|  | 		switch channel.Type { | ||||||
|  | 		case channeltype.Azure: | ||||||
|  | 			if cfg.APIVersion == "" { | ||||||
|  | 				cfg.APIVersion = *channel.Other | ||||||
|  | 			} | ||||||
|  | 		case channeltype.Xunfei: | ||||||
|  | 			if cfg.APIVersion == "" { | ||||||
|  | 				cfg.APIVersion = *channel.Other | ||||||
|  | 			} | ||||||
|  | 		case channeltype.Gemini: | ||||||
|  | 			if cfg.APIVersion == "" { | ||||||
|  | 				cfg.APIVersion = *channel.Other | ||||||
|  | 			} | ||||||
|  | 		case channeltype.AIProxyLibrary: | ||||||
|  | 			if cfg.LibraryID == "" { | ||||||
|  | 				cfg.LibraryID = *channel.Other | ||||||
|  | 			} | ||||||
|  | 		case channeltype.Ali: | ||||||
|  | 			if cfg.Plugin == "" { | ||||||
|  | 				cfg.Plugin = *channel.Other | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	c.Set(ctxkey.Config, cfg) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -3,14 +3,14 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"one-api/common" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetUpLogger(server *gin.Engine) { | func SetUpLogger(server *gin.Engine) { | ||||||
| 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | ||||||
| 		var requestID string | 		var requestID string | ||||||
| 		if param.Keys != nil { | 		if param.Keys != nil { | ||||||
| 			requestID = param.Keys[common.RequestIdKey].(string) | 			requestID = param.Keys[helper.RequestIdKey].(string) | ||||||
| 		} | 		} | ||||||
| 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | ||||||
| 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | ||||||
|   | |||||||
| @@ -4,8 +4,9 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st | |||||||
| 	} | 	} | ||||||
| 	if listLength < int64(maxRequestNum) { | 	if listLength < int64(maxRequestNum) { | ||||||
| 		rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | 		rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | ||||||
| 		rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | 		rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | ||||||
| 	} else { | 	} else { | ||||||
| 		oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() | 		oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() | ||||||
| 		oldTime, err := time.Parse(timeFormat, oldTimeStr) | 		oldTime, err := time.Parse(timeFormat, oldTimeStr) | ||||||
| @@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st | |||||||
| 		// time.Since will return negative number! | 		// time.Since will return negative number! | ||||||
| 		// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows | 		// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows | ||||||
| 		if int64(nowTime.Sub(oldTime).Seconds()) < duration { | 		if int64(nowTime.Sub(oldTime).Seconds()) < duration { | ||||||
| 			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | 			rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | ||||||
| 			c.Status(http.StatusTooManyRequests) | 			c.Status(http.StatusTooManyRequests) | ||||||
| 			c.Abort() | 			c.Abort() | ||||||
| 			return | 			return | ||||||
| 		} else { | 		} else { | ||||||
| 			rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | 			rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | ||||||
| 			rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) | 			rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) | ||||||
| 			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | 			rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi | |||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		// It's safe to call multi times. | 		// It's safe to call multi times. | ||||||
| 		inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) | 		inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) | ||||||
| 		return func(c *gin.Context) { | 		return func(c *gin.Context) { | ||||||
| 			memoryRateLimiter(c, maxRequestNum, duration, mark) | 			memoryRateLimiter(c, maxRequestNum, duration, mark) | ||||||
| 		} | 		} | ||||||
| @@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi | |||||||
| } | } | ||||||
|  |  | ||||||
| func GlobalWebRateLimit() func(c *gin.Context) { | func GlobalWebRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") | 	return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") | ||||||
| } | } | ||||||
|  |  | ||||||
| func GlobalAPIRateLimit() func(c *gin.Context) { | func GlobalAPIRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") | 	return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") | ||||||
| } | } | ||||||
|  |  | ||||||
| func CriticalRateLimit() func(c *gin.Context) { | func CriticalRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") | 	return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") | ||||||
| } | } | ||||||
|  |  | ||||||
| func DownloadRateLimit() func(c *gin.Context) { | func DownloadRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") | 	return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") | ||||||
| } | } | ||||||
|  |  | ||||||
| func UploadRateLimit() func(c *gin.Context) { | func UploadRateLimit() func(c *gin.Context) { | ||||||
| 	return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") | 	return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,8 +3,9 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" |  | ||||||
| 	"runtime/debug" | 	"runtime/debug" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc { | |||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		defer func() { | 		defer func() { | ||||||
| 			if err := recover(); err != nil { | 			if err := recover(); err != nil { | ||||||
| 				common.SysError(fmt.Sprintf("panic detected: %v", err)) | 				ctx := c.Request.Context() | ||||||
| 				common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | 				logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) | ||||||
|  | 				logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||||
|  | 				logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) | ||||||
|  | 				body, _ := common.GetRequestBody(c) | ||||||
|  | 				logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) | ||||||
| 				c.JSON(http.StatusInternalServerError, gin.H{ | 				c.JSON(http.StatusInternalServerError, gin.H{ | ||||||
| 					"error": gin.H{ | 					"error": gin.H{ | ||||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), | ||||||
| 						"type":    "one_api_panic", | 						"type":    "one_api_panic", | ||||||
| 					}, | 					}, | ||||||
| 				}) | 				}) | ||||||
|   | |||||||
| @@ -3,16 +3,16 @@ package middleware | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"one-api/common" | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func RequestId() func(c *gin.Context) { | func RequestId() func(c *gin.Context) { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		id := common.GetTimeString() + common.GetRandomString(8) | 		id := helper.GenRequestID() | ||||||
| 		c.Set(common.RequestIdKey, id) | 		c.Set(helper.RequestIdKey, id) | ||||||
| 		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) | 		ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) | ||||||
| 		c.Request = c.Request.WithContext(ctx) | 		c.Request = c.Request.WithContext(ctx) | ||||||
| 		c.Header(common.RequestIdKey, id) | 		c.Header(helper.RequestIdKey, id) | ||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,9 +4,10 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"one-api/common" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type turnstileCheckResponse struct { | type turnstileCheckResponse struct { | ||||||
| @@ -15,7 +16,7 @@ type turnstileCheckResponse struct { | |||||||
|  |  | ||||||
| func TurnstileCheck() gin.HandlerFunc { | func TurnstileCheck() gin.HandlerFunc { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		if common.TurnstileCheckEnabled { | 		if config.TurnstileCheckEnabled { | ||||||
| 			session := sessions.Default(c) | 			session := sessions.Default(c) | ||||||
| 			turnstileChecked := session.Get("turnstile") | 			turnstileChecked := session.Get("turnstile") | ||||||
| 			if turnstileChecked != nil { | 			if turnstileChecked != nil { | ||||||
| @@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ | 			rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ | ||||||
| 				"secret":   {common.TurnstileSecretKey}, | 				"secret":   {config.TurnstileSecretKey}, | ||||||
| 				"response": {response}, | 				"response": {response}, | ||||||
| 				"remoteip": {c.ClientIP()}, | 				"remoteip": {c.ClientIP()}, | ||||||
| 			}) | 			}) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError(err.Error()) | 				logger.SysError(err.Error()) | ||||||
| 				c.JSON(http.StatusOK, gin.H{ | 				c.JSON(http.StatusOK, gin.H{ | ||||||
| 					"success": false, | 					"success": false, | ||||||
| 					"message": err.Error(), | 					"message": err.Error(), | ||||||
| @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { | |||||||
| 			var res turnstileCheckResponse | 			var res turnstileCheckResponse | ||||||
| 			err = json.NewDecoder(rawRes.Body).Decode(&res) | 			err = json.NewDecoder(rawRes.Body).Decode(&res) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError(err.Error()) | 				logger.SysError(err.Error()) | ||||||
| 				c.JSON(http.StatusOK, gin.H{ | 				c.JSON(http.StatusOK, gin.H{ | ||||||
| 					"success": false, | 					"success": false, | ||||||
| 					"message": err.Error(), | 					"message": err.Error(), | ||||||
|   | |||||||
| @@ -1,17 +1,60 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func abortWithMessage(c *gin.Context, statusCode int, message string) { | func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||||
| 	c.JSON(statusCode, gin.H{ | 	c.JSON(statusCode, gin.H{ | ||||||
| 		"error": gin.H{ | 		"error": gin.H{ | ||||||
| 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), | 			"message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)), | ||||||
| 			"type":    "one_api_error", | 			"type":    "one_api_error", | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| 	c.Abort() | 	c.Abort() | ||||||
| 	common.LogError(c.Request.Context(), message) | 	logger.Error(c.Request.Context(), message) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getRequestModel(c *gin.Context) (string, error) { | ||||||
|  | 	var modelRequest ModelRequest | ||||||
|  | 	err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||||
|  | 		if modelRequest.Model == "" { | ||||||
|  | 			modelRequest.Model = "text-moderation-stable" | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||||
|  | 		if modelRequest.Model == "" { | ||||||
|  | 			modelRequest.Model = c.Param("model") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||||
|  | 		if modelRequest.Model == "" { | ||||||
|  | 			modelRequest.Model = "dall-e-2" | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||||
|  | 		if modelRequest.Model == "" { | ||||||
|  | 			modelRequest.Model = "whisper-1" | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return modelRequest.Model, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func isModelInList(modelName string, models string) bool { | ||||||
|  | 	modelList := strings.Split(models, ",") | ||||||
|  | 	for _, model := range modelList { | ||||||
|  | 		if modelName == model { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,7 +1,10 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"one-api/common" | 	"context" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | 	"sort" | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -13,7 +16,7 @@ type Ability struct { | |||||||
| 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"` | 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { | ||||||
| 	ability := Ability{} | 	ability := Ability{} | ||||||
| 	groupCol := "`group`" | 	groupCol := "`group`" | ||||||
| 	trueVal := "1" | 	trueVal := "1" | ||||||
| @@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var err error = nil | 	var err error = nil | ||||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | 	var channelQuery *gorm.DB | ||||||
| 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | 	if ignoreFirstPriority { | ||||||
|  | 		channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||||
|  | 	} else { | ||||||
|  | 		maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||||
|  | 		channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | ||||||
|  | 	} | ||||||
| 	if common.UsingSQLite || common.UsingPostgreSQL { | 	if common.UsingSQLite || common.UsingPostgreSQL { | ||||||
| 		err = channelQuery.Order("RANDOM()").First(&ability).Error | 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||||
| 	} else { | 	} else { | ||||||
| @@ -49,7 +57,7 @@ func (channel *Channel) AddAbilities() error { | |||||||
| 				Group:     group, | 				Group:     group, | ||||||
| 				Model:     model, | 				Model:     model, | ||||||
| 				ChannelId: channel.Id, | 				ChannelId: channel.Id, | ||||||
| 				Enabled:   channel.Status == common.ChannelStatusEnabled, | 				Enabled:   channel.Status == ChannelStatusEnabled, | ||||||
| 				Priority:  channel.Priority, | 				Priority:  channel.Priority, | ||||||
| 			} | 			} | ||||||
| 			abilities = append(abilities, ability) | 			abilities = append(abilities, ability) | ||||||
| @@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { | |||||||
| func UpdateAbilityStatus(channelId int, status bool) error { | func UpdateAbilityStatus(channelId int, status bool) error { | ||||||
| 	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error | 	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||||
|  | 	groupCol := "`group`" | ||||||
|  | 	trueVal := "1" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		groupCol = `"group"` | ||||||
|  | 		trueVal = "true" | ||||||
|  | 	} | ||||||
|  | 	var models []string | ||||||
|  | 	err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	sort.Strings(models) | ||||||
|  | 	return models, err | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,11 +1,15 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"one-api/common" |  | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -14,10 +18,11 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	TokenCacheSeconds         = common.SyncFrequency | 	TokenCacheSeconds         = config.SyncFrequency | ||||||
| 	UserId2GroupCacheSeconds  = common.SyncFrequency | 	UserId2GroupCacheSeconds  = config.SyncFrequency | ||||||
| 	UserId2QuotaCacheSeconds  = common.SyncFrequency | 	UserId2QuotaCacheSeconds  = config.SyncFrequency | ||||||
| 	UserId2StatusCacheSeconds = common.SyncFrequency | 	UserId2StatusCacheSeconds = config.SyncFrequency | ||||||
|  | 	GroupModelsCacheSeconds   = config.SyncFrequency | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func CacheGetTokenByKey(key string) (*Token, error) { | func CacheGetTokenByKey(key string) (*Token, error) { | ||||||
| @@ -42,7 +47,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { | |||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set token error: " + err.Error()) | 			logger.SysError("Redis set token error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 		return &token, nil | 		return &token, nil | ||||||
| 	} | 	} | ||||||
| @@ -62,37 +67,48 @@ func CacheGetUserGroup(id int) (group string, err error) { | |||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set user group error: " + err.Error()) | 			logger.SysError("Redis set user group error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return group, err | 	return group, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetUserQuota(id int) (quota int, err error) { | func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||||
|  | 	quota, err = GetUserQuota(id) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Error(ctx, "Redis set user quota error: "+err.Error()) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return GetUserQuota(id) | 		return GetUserQuota(id) | ||||||
| 	} | 	} | ||||||
| 	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) | 	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		quota, err = GetUserQuota(id) | 		return fetchAndUpdateUserQuota(ctx, id) | ||||||
| 		if err != nil { |  | ||||||
| 			return 0, err |  | ||||||
| 		} |  | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("Redis set user quota error: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		return quota, err |  | ||||||
| 	} | 	} | ||||||
| 	quota, err = strconv.Atoi(quotaString) | 	quota, err = strconv.ParseInt(quotaString, 10, 64) | ||||||
| 	return quota, err | 	if err != nil { | ||||||
|  | 		return 0, nil | ||||||
|  | 	} | ||||||
|  | 	if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db | ||||||
|  | 		logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) | ||||||
|  | 		return fetchAndUpdateUserQuota(ctx, id) | ||||||
|  | 	} | ||||||
|  | 	return quota, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheUpdateUserQuota(id int) error { | func CacheUpdateUserQuota(ctx context.Context, id int) error { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	quota, err := GetUserQuota(id) | 	quota, err := CacheGetUserQuota(ctx, id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -100,7 +116,7 @@ func CacheUpdateUserQuota(id int) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheDecreaseUserQuota(id int, quota int) error { | func CacheDecreaseUserQuota(id int, quota int64) error { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| @@ -127,18 +143,37 @@ func CacheIsUserEnabled(userId int) (bool, error) { | |||||||
| 	} | 	} | ||||||
| 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("Redis set user enabled error: " + err.Error()) | 		logger.SysError("Redis set user enabled error: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return userEnabled, err | 	return userEnabled, err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||||
|  | 	if !common.RedisEnabled { | ||||||
|  | 		return GetGroupModels(ctx, group) | ||||||
|  | 	} | ||||||
|  | 	modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) | ||||||
|  | 	if err == nil { | ||||||
|  | 		return strings.Split(modelsStr, ","), nil | ||||||
|  | 	} | ||||||
|  | 	models, err := GetGroupModels(ctx, group) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError("Redis set group models error: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	return models, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| var group2model2channels map[string]map[string][]*Channel | var group2model2channels map[string]map[string][]*Channel | ||||||
| var channelSyncLock sync.RWMutex | var channelSyncLock sync.RWMutex | ||||||
|  |  | ||||||
| func InitChannelCache() { | func InitChannelCache() { | ||||||
| 	newChannelId2channel := make(map[int]*Channel) | 	newChannelId2channel := make(map[int]*Channel) | ||||||
| 	var channels []*Channel | 	var channels []*Channel | ||||||
| 	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) | 	DB.Where("status = ?", ChannelStatusEnabled).Find(&channels) | ||||||
| 	for _, channel := range channels { | 	for _, channel := range channels { | ||||||
| 		newChannelId2channel[channel.Id] = channel | 		newChannelId2channel[channel.Id] = channel | ||||||
| 	} | 	} | ||||||
| @@ -178,20 +213,20 @@ func InitChannelCache() { | |||||||
| 	channelSyncLock.Lock() | 	channelSyncLock.Lock() | ||||||
| 	group2model2channels = newGroup2model2channels | 	group2model2channels = newGroup2model2channels | ||||||
| 	channelSyncLock.Unlock() | 	channelSyncLock.Unlock() | ||||||
| 	common.SysLog("channels synced from database") | 	logger.SysLog("channels synced from database") | ||||||
| } | } | ||||||
|  |  | ||||||
| func SyncChannelCache(frequency int) { | func SyncChannelCache(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Second) | 		time.Sleep(time.Duration(frequency) * time.Second) | ||||||
| 		common.SysLog("syncing channels from database") | 		logger.SysLog("syncing channels from database") | ||||||
| 		InitChannelCache() | 		InitChannelCache() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { | ||||||
| 	if !common.MemoryCacheEnabled { | 	if !config.MemoryCacheEnabled { | ||||||
| 		return GetRandomSatisfiedChannel(group, model) | 		return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority) | ||||||
| 	} | 	} | ||||||
| 	channelSyncLock.RLock() | 	channelSyncLock.RLock() | ||||||
| 	defer channelSyncLock.RUnlock() | 	defer channelSyncLock.RUnlock() | ||||||
| @@ -211,5 +246,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	idx := rand.Intn(endIdx) | 	idx := rand.Intn(endIdx) | ||||||
|  | 	if ignoreFirstPriority { | ||||||
|  | 		if endIdx < len(channels) { // which means there are more than one priority | ||||||
|  | 			idx = random.RandRange(endIdx, len(channels)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	return channels[idx], nil | 	return channels[idx], nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,14 +1,25 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	ChannelStatusUnknown          = 0 | ||||||
|  | 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! | ||||||
|  | 	ChannelStatusManuallyDisabled = 2 // also don't use 0 | ||||||
|  | 	ChannelStatusAutoDisabled     = 3 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Channel struct { | type Channel struct { | ||||||
| 	Id                 int     `json:"id"` | 	Id                 int     `json:"id"` | ||||||
| 	Type               int     `json:"type" gorm:"default:0"` | 	Type               int     `json:"type" gorm:"default:0"` | ||||||
| 	Key                string  `json:"key" gorm:"not null;index"` | 	Key                string  `json:"key" gorm:"type:text"` | ||||||
| 	Status             int     `json:"status" gorm:"default:1"` | 	Status             int     `json:"status" gorm:"default:1"` | ||||||
| 	Name               string  `json:"name" gorm:"index"` | 	Name               string  `json:"name" gorm:"index"` | ||||||
| 	Weight             *uint   `json:"weight" gorm:"default:0"` | 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||||
| @@ -16,7 +27,7 @@ type Channel struct { | |||||||
| 	TestTime           int64   `json:"test_time" gorm:"bigint"` | 	TestTime           int64   `json:"test_time" gorm:"bigint"` | ||||||
| 	ResponseTime       int     `json:"response_time"` // in milliseconds | 	ResponseTime       int     `json:"response_time"` // in milliseconds | ||||||
| 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | ||||||
| 	Other              string  `json:"other"` | 	Other              *string `json:"other"`   // DEPRECATED: please save config to field Config | ||||||
| 	Balance            float64 `json:"balance"` // in USD | 	Balance            float64 `json:"balance"` // in USD | ||||||
| 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | ||||||
| 	Models             string  `json:"models"` | 	Models             string  `json:"models"` | ||||||
| @@ -24,25 +35,35 @@ type Channel struct { | |||||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||||
|  | 	Config             string  `json:"config"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | type ChannelConfig struct { | ||||||
|  | 	Region     string `json:"region,omitempty"` | ||||||
|  | 	SK         string `json:"sk,omitempty"` | ||||||
|  | 	AK         string `json:"ak,omitempty"` | ||||||
|  | 	UserID     string `json:"user_id,omitempty"` | ||||||
|  | 	APIVersion string `json:"api_version,omitempty"` | ||||||
|  | 	LibraryID  string `json:"library_id,omitempty"` | ||||||
|  | 	Plugin     string `json:"plugin,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { | ||||||
| 	var channels []*Channel | 	var channels []*Channel | ||||||
| 	var err error | 	var err error | ||||||
| 	if selectAll { | 	switch scope { | ||||||
|  | 	case "all": | ||||||
| 		err = DB.Order("id desc").Find(&channels).Error | 		err = DB.Order("id desc").Find(&channels).Error | ||||||
| 	} else { | 	case "disabled": | ||||||
|  | 		err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error | ||||||
|  | 	default: | ||||||
| 		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | 		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | ||||||
| 	} | 	} | ||||||
| 	return channels, err | 	return channels, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||||
| 	keyCol := "`key`" | 	err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error | ||||||
| 	if common.UsingPostgreSQL { |  | ||||||
| 		keyCol = `"key"` |  | ||||||
| 	} |  | ||||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error |  | ||||||
| 	return channels, err | 	return channels, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -86,11 +107,17 @@ func (channel *Channel) GetBaseURL() string { | |||||||
| 	return *channel.BaseURL | 	return *channel.BaseURL | ||||||
| } | } | ||||||
|  |  | ||||||
| func (channel *Channel) GetModelMapping() string { | func (channel *Channel) GetModelMapping() map[string]string { | ||||||
| 	if channel.ModelMapping == nil { | 	if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" { | ||||||
| 		return "" | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return *channel.ModelMapping | 	modelMapping := make(map[string]string) | ||||||
|  | 	err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error())) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return modelMapping | ||||||
| } | } | ||||||
|  |  | ||||||
| func (channel *Channel) Insert() error { | func (channel *Channel) Insert() error { | ||||||
| @@ -116,21 +143,21 @@ func (channel *Channel) Update() error { | |||||||
|  |  | ||||||
| func (channel *Channel) UpdateResponseTime(responseTime int64) { | func (channel *Channel) UpdateResponseTime(responseTime int64) { | ||||||
| 	err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ | 	err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ | ||||||
| 		TestTime:     common.GetTimestamp(), | 		TestTime:     helper.GetTimestamp(), | ||||||
| 		ResponseTime: int(responseTime), | 		ResponseTime: int(responseTime), | ||||||
| 	}).Error | 	}).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update response time: " + err.Error()) | 		logger.SysError("failed to update response time: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (channel *Channel) UpdateBalance(balance float64) { | func (channel *Channel) UpdateBalance(balance float64) { | ||||||
| 	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ | 	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ | ||||||
| 		BalanceUpdatedTime: common.GetTimestamp(), | 		BalanceUpdatedTime: helper.GetTimestamp(), | ||||||
| 		Balance:            balance, | 		Balance:            balance, | ||||||
| 	}).Error | 	}).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update balance: " + err.Error()) | 		logger.SysError("failed to update balance: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -144,29 +171,41 @@ func (channel *Channel) Delete() error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelStatusById(id int, status int) { | func (channel *Channel) LoadConfig() (ChannelConfig, error) { | ||||||
| 	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) | 	var cfg ChannelConfig | ||||||
|  | 	if channel.Config == "" { | ||||||
|  | 		return cfg, nil | ||||||
|  | 	} | ||||||
|  | 	err := json.Unmarshal([]byte(channel.Config), &cfg) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update ability status: " + err.Error()) | 		return cfg, err | ||||||
|  | 	} | ||||||
|  | 	return cfg, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func UpdateChannelStatusById(id int, status int) { | ||||||
|  | 	err := UpdateAbilityStatus(id, status == ChannelStatusEnabled) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError("failed to update ability status: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error | 	err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update channel status: " + err.Error()) | 		logger.SysError("failed to update channel status: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelUsedQuota(id int, quota int) { | func UpdateChannelUsedQuota(id int, quota int64) { | ||||||
| 	if common.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	updateChannelUsedQuota(id, quota) | 	updateChannelUsedQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateChannelUsedQuota(id int, quota int) { | func updateChannelUsedQuota(id int, quota int64) { | ||||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | 		logger.SysError("failed to update channel used quota: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -176,6 +215,6 @@ func DeleteChannelByStatus(status int64) (int64, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteDisabledChannel() (int64, error) { | func DeleteDisabledChannel() (int64, error) { | ||||||
| 	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | 	result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||||
| 	return result.RowsAffected, result.Error | 	return result.RowsAffected, result.Error | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										69
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										69
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -3,15 +3,17 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"one-api/common" | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Log struct { | type Log struct { | ||||||
| 	Id               int    `json:"id;index:idx_created_at_id,priority:1"` | 	Id               int    `json:"id"` | ||||||
| 	UserId           int    `json:"user_id" gorm:"index"` | 	UserId           int    `json:"user_id" gorm:"index"` | ||||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` | 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"` | ||||||
| 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||||
| 	Content          string `json:"content"` | 	Content          string `json:"content"` | ||||||
| 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||||
| @@ -32,52 +34,67 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func RecordLog(userId int, logType int, content string) { | func RecordLog(userId int, logType int, content string) { | ||||||
| 	if logType == LogTypeConsume && !common.LogConsumeEnabled { | 	if logType == LogTypeConsume && !config.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	log := &Log{ | 	log := &Log{ | ||||||
| 		UserId:    userId, | 		UserId:    userId, | ||||||
| 		Username:  GetUsernameById(userId), | 		Username:  GetUsernameById(userId), | ||||||
| 		CreatedAt: common.GetTimestamp(), | 		CreatedAt: helper.GetTimestamp(), | ||||||
| 		Type:      logType, | 		Type:      logType, | ||||||
| 		Content:   content, | 		Content:   content, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := LOG_DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to record log: " + err.Error()) | 		logger.SysError("failed to record log: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | func RecordTopupLog(userId int, content string, quota int) { | ||||||
| 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | 	log := &Log{ | ||||||
| 	if !common.LogConsumeEnabled { | 		UserId:    userId, | ||||||
|  | 		Username:  GetUsernameById(userId), | ||||||
|  | 		CreatedAt: helper.GetTimestamp(), | ||||||
|  | 		Type:      LogTypeTopup, | ||||||
|  | 		Content:   content, | ||||||
|  | 		Quota:     quota, | ||||||
|  | 	} | ||||||
|  | 	err := LOG_DB.Create(log).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError("failed to record log: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { | ||||||
|  | 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||||
|  | 	if !config.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	log := &Log{ | 	log := &Log{ | ||||||
| 		UserId:           userId, | 		UserId:           userId, | ||||||
| 		Username:         GetUsernameById(userId), | 		Username:         GetUsernameById(userId), | ||||||
| 		CreatedAt:        common.GetTimestamp(), | 		CreatedAt:        helper.GetTimestamp(), | ||||||
| 		Type:             LogTypeConsume, | 		Type:             LogTypeConsume, | ||||||
| 		Content:          content, | 		Content:          content, | ||||||
| 		PromptTokens:     promptTokens, | 		PromptTokens:     promptTokens, | ||||||
| 		CompletionTokens: completionTokens, | 		CompletionTokens: completionTokens, | ||||||
| 		TokenName:        tokenName, | 		TokenName:        tokenName, | ||||||
| 		ModelName:        modelName, | 		ModelName:        modelName, | ||||||
| 		Quota:            quota, | 		Quota:            int(quota), | ||||||
| 		ChannelId:        channelId, | 		ChannelId:        channelId, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := LOG_DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.LogError(ctx, "failed to record log: "+err.Error()) | 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | ||||||
| 	var tx *gorm.DB | 	var tx *gorm.DB | ||||||
| 	if logType == LogTypeUnknown { | 	if logType == LogTypeUnknown { | ||||||
| 		tx = DB | 		tx = LOG_DB | ||||||
| 	} else { | 	} else { | ||||||
| 		tx = DB.Where("type = ?", logType) | 		tx = LOG_DB.Where("type = ?", logType) | ||||||
| 	} | 	} | ||||||
| 	if modelName != "" { | 	if modelName != "" { | ||||||
| 		tx = tx.Where("model_name = ?", modelName) | 		tx = tx.Where("model_name = ?", modelName) | ||||||
| @@ -104,9 +121,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | |||||||
| func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | ||||||
| 	var tx *gorm.DB | 	var tx *gorm.DB | ||||||
| 	if logType == LogTypeUnknown { | 	if logType == LogTypeUnknown { | ||||||
| 		tx = DB.Where("user_id = ?", userId) | 		tx = LOG_DB.Where("user_id = ?", userId) | ||||||
| 	} else { | 	} else { | ||||||
| 		tx = DB.Where("user_id = ? and type = ?", userId, logType) | 		tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) | ||||||
| 	} | 	} | ||||||
| 	if modelName != "" { | 	if modelName != "" { | ||||||
| 		tx = tx.Where("model_name = ?", modelName) | 		tx = tx.Where("model_name = ?", modelName) | ||||||
| @@ -125,17 +142,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int | |||||||
| } | } | ||||||
|  |  | ||||||
| func SearchAllLogs(keyword string) (logs []*Log, err error) { | func SearchAllLogs(keyword string) (logs []*Log, err error) { | ||||||
| 	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error | 	err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||||
| 	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error | 	err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { | ||||||
| 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | 	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -159,7 +176,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||||
| 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | 	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -180,7 +197,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| } | } | ||||||
|  |  | ||||||
| func DeleteOldLog(targetTimestamp int64) (int64, error) { | func DeleteOldLog(targetTimestamp int64) (int64, error) { | ||||||
| 	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | 	result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||||
| 	return result.RowsAffected, result.Error | 	return result.RowsAffected, result.Error | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -204,7 +221,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis | |||||||
| 		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" | 		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = DB.Raw(` | 	err = LOG_DB.Raw(` | ||||||
| 		SELECT `+groupSelect+`, | 		SELECT `+groupSelect+`, | ||||||
| 		model_name, count(1) as request_count, | 		model_name, count(1) as request_count, | ||||||
| 		sum(quota) as quota, | 		sum(quota) as quota, | ||||||
|   | |||||||
							
								
								
									
										275
									
								
								model/main.go
									
									
									
									
									
								
							
							
						
						
									
										275
									
								
								model/main.go
									
									
									
									
									
								
							| @@ -1,24 +1,31 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/env" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
| 	"gorm.io/driver/mysql" | 	"gorm.io/driver/mysql" | ||||||
| 	"gorm.io/driver/postgres" | 	"gorm.io/driver/postgres" | ||||||
| 	"gorm.io/driver/sqlite" | 	"gorm.io/driver/sqlite" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var DB *gorm.DB | var DB *gorm.DB | ||||||
|  | var LOG_DB *gorm.DB | ||||||
|  |  | ||||||
| func createRootAccountIfNeed() error { | func CreateRootAccountIfNeed() error { | ||||||
| 	var user User | 	var user User | ||||||
| 	//if user.Status != common.UserStatusEnabled { | 	//if user.Status != util.UserStatusEnabled { | ||||||
| 	if err := DB.First(&user).Error; err != nil { | 	if err := DB.First(&user).Error; err != nil { | ||||||
| 		common.SysLog("no user exists, create a root user for you: username is root, password is 123456") | 		logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456") | ||||||
| 		hashedPassword, err := common.Password2Hash("123456") | 		hashedPassword, err := common.Password2Hash("123456") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| @@ -26,107 +33,201 @@ func createRootAccountIfNeed() error { | |||||||
| 		rootUser := User{ | 		rootUser := User{ | ||||||
| 			Username:    "root", | 			Username:    "root", | ||||||
| 			Password:    hashedPassword, | 			Password:    hashedPassword, | ||||||
| 			Role:        common.RoleRootUser, | 			Role:        RoleRootUser, | ||||||
| 			Status:      common.UserStatusEnabled, | 			Status:      UserStatusEnabled, | ||||||
| 			DisplayName: "Root User", | 			DisplayName: "Root User", | ||||||
| 			AccessToken: common.GetUUID(), | 			AccessToken: random.GetUUID(), | ||||||
| 			Quota:       100000000, | 			Quota:       500000000000000, | ||||||
| 		} | 		} | ||||||
| 		DB.Create(&rootUser) | 		DB.Create(&rootUser) | ||||||
|  | 		if config.InitialRootToken != "" { | ||||||
|  | 			logger.SysLog("creating initial root token as requested") | ||||||
|  | 			token := Token{ | ||||||
|  | 				Id:             1, | ||||||
|  | 				UserId:         rootUser.Id, | ||||||
|  | 				Key:            config.InitialRootToken, | ||||||
|  | 				Status:         TokenStatusEnabled, | ||||||
|  | 				Name:           "Initial Root Token", | ||||||
|  | 				CreatedTime:    helper.GetTimestamp(), | ||||||
|  | 				AccessedTime:   helper.GetTimestamp(), | ||||||
|  | 				ExpiredTime:    -1, | ||||||
|  | 				RemainQuota:    500000000000000, | ||||||
|  | 				UnlimitedQuota: true, | ||||||
|  | 			} | ||||||
|  | 			DB.Create(&token) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func chooseDB() (*gorm.DB, error) { | func chooseDB(envName string) (*gorm.DB, error) { | ||||||
| 	if os.Getenv("SQL_DSN") != "" { | 	dsn := os.Getenv(envName) | ||||||
| 		dsn := os.Getenv("SQL_DSN") |  | ||||||
| 		if strings.HasPrefix(dsn, "postgres://") { | 	switch { | ||||||
| 			// Use PostgreSQL | 	case strings.HasPrefix(dsn, "postgres://"): | ||||||
| 			common.SysLog("using PostgreSQL as database") | 		// Use PostgreSQL | ||||||
| 			common.UsingPostgreSQL = true | 		return openPostgreSQL(dsn) | ||||||
| 			return gorm.Open(postgres.New(postgres.Config{ | 	case dsn != "": | ||||||
| 				DSN:                  dsn, |  | ||||||
| 				PreferSimpleProtocol: true, // disables implicit prepared statement usage |  | ||||||
| 			}), &gorm.Config{ |  | ||||||
| 				PrepareStmt: true, // precompile SQL |  | ||||||
| 			}) |  | ||||||
| 		} |  | ||||||
| 		// Use MySQL | 		// Use MySQL | ||||||
| 		common.SysLog("using MySQL as database") | 		return openMySQL(dsn) | ||||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | 	default: | ||||||
| 			PrepareStmt: true, // precompile SQL | 		// Use SQLite | ||||||
| 		}) | 		return openSQLite() | ||||||
| 	} | 	} | ||||||
| 	// Use SQLite | } | ||||||
| 	common.SysLog("SQL_DSN not set, using SQLite as database") |  | ||||||
| 	common.UsingSQLite = true | func openPostgreSQL(dsn string) (*gorm.DB, error) { | ||||||
| 	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) | 	logger.SysLog("using PostgreSQL as database") | ||||||
| 	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ | 	common.UsingPostgreSQL = true | ||||||
|  | 	return gorm.Open(postgres.New(postgres.Config{ | ||||||
|  | 		DSN:                  dsn, | ||||||
|  | 		PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||||
|  | 	}), &gorm.Config{ | ||||||
| 		PrepareStmt: true, // precompile SQL | 		PrepareStmt: true, // precompile SQL | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func InitDB() (err error) { | func openMySQL(dsn string) (*gorm.DB, error) { | ||||||
| 	db, err := chooseDB() | 	logger.SysLog("using MySQL as database") | ||||||
| 	if err == nil { | 	common.UsingMySQL = true | ||||||
| 		if common.DebugEnabled { | 	return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||||
| 			db = db.Debug() | 		PrepareStmt: true, // precompile SQL | ||||||
| 		} | 	}) | ||||||
| 		DB = db |  | ||||||
| 		sqlDB, err := DB.DB() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) |  | ||||||
| 		sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) |  | ||||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) |  | ||||||
|  |  | ||||||
| 		if !common.IsMasterNode { |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 		common.SysLog("database migration started") |  | ||||||
| 		err = db.AutoMigrate(&Channel{}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = db.AutoMigrate(&Token{}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = db.AutoMigrate(&User{}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = db.AutoMigrate(&Option{}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = db.AutoMigrate(&Redemption{}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = db.AutoMigrate(&Ability{}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = db.AutoMigrate(&Log{}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		common.SysLog("database migrated") |  | ||||||
| 		err = createRootAccountIfNeed() |  | ||||||
| 		return err |  | ||||||
| 	} else { |  | ||||||
| 		common.FatalLog(err) |  | ||||||
| 	} |  | ||||||
| 	return err |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func CloseDB() error { | func openSQLite() (*gorm.DB, error) { | ||||||
| 	sqlDB, err := DB.DB() | 	logger.SysLog("SQL_DSN not set, using SQLite as database") | ||||||
|  | 	common.UsingSQLite = true | ||||||
|  | 	dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) | ||||||
|  | 	return gorm.Open(sqlite.Open(dsn), &gorm.Config{ | ||||||
|  | 		PrepareStmt: true, // precompile SQL | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func InitDB() { | ||||||
|  | 	var err error | ||||||
|  | 	DB, err = chooseDB("SQL_DSN") | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sqlDB := setDBConns(DB) | ||||||
|  |  | ||||||
|  | 	if !config.IsMasterNode { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if common.UsingMySQL { | ||||||
|  | 		_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	logger.SysLog("database migration started") | ||||||
|  | 	if err = migrateDB(); err != nil { | ||||||
|  | 		logger.FatalLog("failed to migrate database: " + err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	logger.SysLog("database migrated") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func migrateDB() error { | ||||||
|  | 	var err error | ||||||
|  | 	if err = DB.AutoMigrate(&Channel{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err = DB.AutoMigrate(&Token{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err = DB.AutoMigrate(&User{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err = DB.AutoMigrate(&Option{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err = DB.AutoMigrate(&Redemption{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err = DB.AutoMigrate(&Ability{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err = DB.AutoMigrate(&Log{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err = DB.AutoMigrate(&Channel{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func InitLogDB() { | ||||||
|  | 	if os.Getenv("LOG_SQL_DSN") == "" { | ||||||
|  | 		LOG_DB = DB | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	logger.SysLog("using secondary database for table logs") | ||||||
|  | 	var err error | ||||||
|  | 	LOG_DB, err = chooseDB("LOG_SQL_DSN") | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.FatalLog("failed to initialize secondary database: " + err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	setDBConns(LOG_DB) | ||||||
|  |  | ||||||
|  | 	if !config.IsMasterNode { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	logger.SysLog("secondary database migration started") | ||||||
|  | 	err = migrateLOGDB() | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.FatalLog("failed to migrate secondary database: " + err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	logger.SysLog("secondary database migrated") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func migrateLOGDB() error { | ||||||
|  | 	var err error | ||||||
|  | 	if err = LOG_DB.AutoMigrate(&Log{}); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func setDBConns(db *gorm.DB) *sql.DB { | ||||||
|  | 	if config.DebugSQLEnabled { | ||||||
|  | 		db = db.Debug() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sqlDB, err := db.DB() | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.FatalLog("failed to connect database: " + err.Error()) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||||
|  | 	sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||||
|  | 	sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||||
|  | 	return sqlDB | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func closeDB(db *gorm.DB) error { | ||||||
|  | 	sqlDB, err := db.DB() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	err = sqlDB.Close() | 	err = sqlDB.Close() | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func CloseDB() error { | ||||||
|  | 	if LOG_DB != DB { | ||||||
|  | 		err := closeDB(LOG_DB) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return closeDB(DB) | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										235
									
								
								model/option.go
									
									
									
									
									
								
							
							
						
						
									
										235
									
								
								model/option.go
									
									
									
									
									
								
							| @@ -1,7 +1,9 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"one-api/common" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -20,69 +22,71 @@ func AllOption() ([]*Option, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func InitOptionMap() { | func InitOptionMap() { | ||||||
| 	common.OptionMapRWMutex.Lock() | 	config.OptionMapRWMutex.Lock() | ||||||
| 	common.OptionMap = make(map[string]string) | 	config.OptionMap = make(map[string]string) | ||||||
| 	common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) | 	config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) | ||||||
| 	common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) | 	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) | ||||||
| 	common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) | 	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) | ||||||
| 	common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) | 	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) | ||||||
| 	common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) | 	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) | ||||||
| 	common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) | 	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) | ||||||
| 	common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) | 	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) | ||||||
| 	common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) | 	config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) | ||||||
| 	common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) | 	config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) | ||||||
| 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) | 	config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) | ||||||
| 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) | 	config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) | ||||||
| 	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) | 	config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) | ||||||
| 	common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) | 	config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) | ||||||
| 	common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) | 	config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) | ||||||
| 	common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) | 	config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) | ||||||
| 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | 	config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") | ||||||
| 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) | 	config.OptionMap["SMTPServer"] = "" | ||||||
| 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) | 	config.OptionMap["SMTPFrom"] = "" | ||||||
| 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) | 	config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) | ||||||
| 	common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") | 	config.OptionMap["SMTPAccount"] = "" | ||||||
| 	common.OptionMap["SMTPServer"] = "" | 	config.OptionMap["SMTPToken"] = "" | ||||||
| 	common.OptionMap["SMTPFrom"] = "" | 	config.OptionMap["Notice"] = "" | ||||||
| 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) | 	config.OptionMap["About"] = "" | ||||||
| 	common.OptionMap["SMTPAccount"] = "" | 	config.OptionMap["HomePageContent"] = "" | ||||||
| 	common.OptionMap["SMTPToken"] = "" | 	config.OptionMap["Footer"] = config.Footer | ||||||
| 	common.OptionMap["Notice"] = "" | 	config.OptionMap["SystemName"] = config.SystemName | ||||||
| 	common.OptionMap["About"] = "" | 	config.OptionMap["Logo"] = config.Logo | ||||||
| 	common.OptionMap["HomePageContent"] = "" | 	config.OptionMap["ServerAddress"] = "" | ||||||
| 	common.OptionMap["Footer"] = common.Footer | 	config.OptionMap["GitHubClientId"] = "" | ||||||
| 	common.OptionMap["SystemName"] = common.SystemName | 	config.OptionMap["GitHubClientSecret"] = "" | ||||||
| 	common.OptionMap["Logo"] = common.Logo | 	config.OptionMap["WeChatServerAddress"] = "" | ||||||
| 	common.OptionMap["ServerAddress"] = "" | 	config.OptionMap["WeChatServerToken"] = "" | ||||||
| 	common.OptionMap["GitHubClientId"] = "" | 	config.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||||
| 	common.OptionMap["GitHubClientSecret"] = "" | 	config.OptionMap["MessagePusherAddress"] = "" | ||||||
| 	common.OptionMap["WeChatServerAddress"] = "" | 	config.OptionMap["MessagePusherToken"] = "" | ||||||
| 	common.OptionMap["WeChatServerToken"] = "" | 	config.OptionMap["TurnstileSiteKey"] = "" | ||||||
| 	common.OptionMap["WeChatAccountQRCodeImageURL"] = "" | 	config.OptionMap["TurnstileSecretKey"] = "" | ||||||
| 	common.OptionMap["TurnstileSiteKey"] = "" | 	config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) | ||||||
| 	common.OptionMap["TurnstileSecretKey"] = "" | 	config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) | ||||||
| 	common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) | 	config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) | ||||||
| 	common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) | 	config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) | ||||||
| 	common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) | 	config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) | ||||||
| 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) | 	config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString() | ||||||
| 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) | 	config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString() | ||||||
| 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | 	config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString() | ||||||
| 	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | 	config.OptionMap["TopUpLink"] = config.TopUpLink | ||||||
| 	common.OptionMap["TopUpLink"] = common.TopUpLink | 	config.OptionMap["ChatLink"] = config.ChatLink | ||||||
| 	common.OptionMap["ChatLink"] = common.ChatLink | 	config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) | ||||||
| 	common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) | 	config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) | ||||||
| 	common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) | 	config.OptionMap["Theme"] = config.Theme | ||||||
| 	common.OptionMap["Theme"] = common.Theme | 	config.OptionMapRWMutex.Unlock() | ||||||
| 	common.OptionMapRWMutex.Unlock() |  | ||||||
| 	loadOptionsFromDatabase() | 	loadOptionsFromDatabase() | ||||||
| } | } | ||||||
|  |  | ||||||
| func loadOptionsFromDatabase() { | func loadOptionsFromDatabase() { | ||||||
| 	options, _ := AllOption() | 	options, _ := AllOption() | ||||||
| 	for _, option := range options { | 	for _, option := range options { | ||||||
|  | 		if option.Key == "ModelRatio" { | ||||||
|  | 			option.Value = billingratio.AddNewMissingRatio(option.Value) | ||||||
|  | 		} | ||||||
| 		err := updateOptionMap(option.Key, option.Value) | 		err := updateOptionMap(option.Key, option.Value) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("failed to update option map: " + err.Error()) | 			logger.SysError("failed to update option map: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -90,7 +94,7 @@ func loadOptionsFromDatabase() { | |||||||
| func SyncOptions(frequency int) { | func SyncOptions(frequency int) { | ||||||
| 	for { | 	for { | ||||||
| 		time.Sleep(time.Duration(frequency) * time.Second) | 		time.Sleep(time.Duration(frequency) * time.Second) | ||||||
| 		common.SysLog("syncing options from database") | 		logger.SysLog("syncing options from database") | ||||||
| 		loadOptionsFromDatabase() | 		loadOptionsFromDatabase() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -112,117 +116,114 @@ func UpdateOption(key string, value string) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func updateOptionMap(key string, value string) (err error) { | func updateOptionMap(key string, value string) (err error) { | ||||||
| 	common.OptionMapRWMutex.Lock() | 	config.OptionMapRWMutex.Lock() | ||||||
| 	defer common.OptionMapRWMutex.Unlock() | 	defer config.OptionMapRWMutex.Unlock() | ||||||
| 	common.OptionMap[key] = value | 	config.OptionMap[key] = value | ||||||
| 	if strings.HasSuffix(key, "Permission") { |  | ||||||
| 		intValue, _ := strconv.Atoi(value) |  | ||||||
| 		switch key { |  | ||||||
| 		case "FileUploadPermission": |  | ||||||
| 			common.FileUploadPermission = intValue |  | ||||||
| 		case "FileDownloadPermission": |  | ||||||
| 			common.FileDownloadPermission = intValue |  | ||||||
| 		case "ImageUploadPermission": |  | ||||||
| 			common.ImageUploadPermission = intValue |  | ||||||
| 		case "ImageDownloadPermission": |  | ||||||
| 			common.ImageDownloadPermission = intValue |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if strings.HasSuffix(key, "Enabled") { | 	if strings.HasSuffix(key, "Enabled") { | ||||||
| 		boolValue := value == "true" | 		boolValue := value == "true" | ||||||
| 		switch key { | 		switch key { | ||||||
| 		case "PasswordRegisterEnabled": | 		case "PasswordRegisterEnabled": | ||||||
| 			common.PasswordRegisterEnabled = boolValue | 			config.PasswordRegisterEnabled = boolValue | ||||||
| 		case "PasswordLoginEnabled": | 		case "PasswordLoginEnabled": | ||||||
| 			common.PasswordLoginEnabled = boolValue | 			config.PasswordLoginEnabled = boolValue | ||||||
| 		case "EmailVerificationEnabled": | 		case "EmailVerificationEnabled": | ||||||
| 			common.EmailVerificationEnabled = boolValue | 			config.EmailVerificationEnabled = boolValue | ||||||
| 		case "GitHubOAuthEnabled": | 		case "GitHubOAuthEnabled": | ||||||
| 			common.GitHubOAuthEnabled = boolValue | 			config.GitHubOAuthEnabled = boolValue | ||||||
| 		case "WeChatAuthEnabled": | 		case "WeChatAuthEnabled": | ||||||
| 			common.WeChatAuthEnabled = boolValue | 			config.WeChatAuthEnabled = boolValue | ||||||
| 		case "TurnstileCheckEnabled": | 		case "TurnstileCheckEnabled": | ||||||
| 			common.TurnstileCheckEnabled = boolValue | 			config.TurnstileCheckEnabled = boolValue | ||||||
| 		case "RegisterEnabled": | 		case "RegisterEnabled": | ||||||
| 			common.RegisterEnabled = boolValue | 			config.RegisterEnabled = boolValue | ||||||
| 		case "EmailDomainRestrictionEnabled": | 		case "EmailDomainRestrictionEnabled": | ||||||
| 			common.EmailDomainRestrictionEnabled = boolValue | 			config.EmailDomainRestrictionEnabled = boolValue | ||||||
| 		case "AutomaticDisableChannelEnabled": | 		case "AutomaticDisableChannelEnabled": | ||||||
| 			common.AutomaticDisableChannelEnabled = boolValue | 			config.AutomaticDisableChannelEnabled = boolValue | ||||||
| 		case "AutomaticEnableChannelEnabled": | 		case "AutomaticEnableChannelEnabled": | ||||||
| 			common.AutomaticEnableChannelEnabled = boolValue | 			config.AutomaticEnableChannelEnabled = boolValue | ||||||
| 		case "ApproximateTokenEnabled": | 		case "ApproximateTokenEnabled": | ||||||
| 			common.ApproximateTokenEnabled = boolValue | 			config.ApproximateTokenEnabled = boolValue | ||||||
| 		case "LogConsumeEnabled": | 		case "LogConsumeEnabled": | ||||||
| 			common.LogConsumeEnabled = boolValue | 			config.LogConsumeEnabled = boolValue | ||||||
| 		case "DisplayInCurrencyEnabled": | 		case "DisplayInCurrencyEnabled": | ||||||
| 			common.DisplayInCurrencyEnabled = boolValue | 			config.DisplayInCurrencyEnabled = boolValue | ||||||
| 		case "DisplayTokenStatEnabled": | 		case "DisplayTokenStatEnabled": | ||||||
| 			common.DisplayTokenStatEnabled = boolValue | 			config.DisplayTokenStatEnabled = boolValue | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	switch key { | 	switch key { | ||||||
| 	case "EmailDomainWhitelist": | 	case "EmailDomainWhitelist": | ||||||
| 		common.EmailDomainWhitelist = strings.Split(value, ",") | 		config.EmailDomainWhitelist = strings.Split(value, ",") | ||||||
| 	case "SMTPServer": | 	case "SMTPServer": | ||||||
| 		common.SMTPServer = value | 		config.SMTPServer = value | ||||||
| 	case "SMTPPort": | 	case "SMTPPort": | ||||||
| 		intValue, _ := strconv.Atoi(value) | 		intValue, _ := strconv.Atoi(value) | ||||||
| 		common.SMTPPort = intValue | 		config.SMTPPort = intValue | ||||||
| 	case "SMTPAccount": | 	case "SMTPAccount": | ||||||
| 		common.SMTPAccount = value | 		config.SMTPAccount = value | ||||||
| 	case "SMTPFrom": | 	case "SMTPFrom": | ||||||
| 		common.SMTPFrom = value | 		config.SMTPFrom = value | ||||||
| 	case "SMTPToken": | 	case "SMTPToken": | ||||||
| 		common.SMTPToken = value | 		config.SMTPToken = value | ||||||
| 	case "ServerAddress": | 	case "ServerAddress": | ||||||
| 		common.ServerAddress = value | 		config.ServerAddress = value | ||||||
| 	case "GitHubClientId": | 	case "GitHubClientId": | ||||||
| 		common.GitHubClientId = value | 		config.GitHubClientId = value | ||||||
| 	case "GitHubClientSecret": | 	case "GitHubClientSecret": | ||||||
| 		common.GitHubClientSecret = value | 		config.GitHubClientSecret = value | ||||||
|  | 	case "LarkClientId": | ||||||
|  | 		config.LarkClientId = value | ||||||
|  | 	case "LarkClientSecret": | ||||||
|  | 		config.LarkClientSecret = value | ||||||
| 	case "Footer": | 	case "Footer": | ||||||
| 		common.Footer = value | 		config.Footer = value | ||||||
| 	case "SystemName": | 	case "SystemName": | ||||||
| 		common.SystemName = value | 		config.SystemName = value | ||||||
| 	case "Logo": | 	case "Logo": | ||||||
| 		common.Logo = value | 		config.Logo = value | ||||||
| 	case "WeChatServerAddress": | 	case "WeChatServerAddress": | ||||||
| 		common.WeChatServerAddress = value | 		config.WeChatServerAddress = value | ||||||
| 	case "WeChatServerToken": | 	case "WeChatServerToken": | ||||||
| 		common.WeChatServerToken = value | 		config.WeChatServerToken = value | ||||||
| 	case "WeChatAccountQRCodeImageURL": | 	case "WeChatAccountQRCodeImageURL": | ||||||
| 		common.WeChatAccountQRCodeImageURL = value | 		config.WeChatAccountQRCodeImageURL = value | ||||||
|  | 	case "MessagePusherAddress": | ||||||
|  | 		config.MessagePusherAddress = value | ||||||
|  | 	case "MessagePusherToken": | ||||||
|  | 		config.MessagePusherToken = value | ||||||
| 	case "TurnstileSiteKey": | 	case "TurnstileSiteKey": | ||||||
| 		common.TurnstileSiteKey = value | 		config.TurnstileSiteKey = value | ||||||
| 	case "TurnstileSecretKey": | 	case "TurnstileSecretKey": | ||||||
| 		common.TurnstileSecretKey = value | 		config.TurnstileSecretKey = value | ||||||
| 	case "QuotaForNewUser": | 	case "QuotaForNewUser": | ||||||
| 		common.QuotaForNewUser, _ = strconv.Atoi(value) | 		config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "QuotaForInviter": | 	case "QuotaForInviter": | ||||||
| 		common.QuotaForInviter, _ = strconv.Atoi(value) | 		config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "QuotaForInvitee": | 	case "QuotaForInvitee": | ||||||
| 		common.QuotaForInvitee, _ = strconv.Atoi(value) | 		config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "QuotaRemindThreshold": | 	case "QuotaRemindThreshold": | ||||||
| 		common.QuotaRemindThreshold, _ = strconv.Atoi(value) | 		config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "PreConsumedQuota": | 	case "PreConsumedQuota": | ||||||
| 		common.PreConsumedQuota, _ = strconv.Atoi(value) | 		config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) | ||||||
| 	case "RetryTimes": | 	case "RetryTimes": | ||||||
| 		common.RetryTimes, _ = strconv.Atoi(value) | 		config.RetryTimes, _ = strconv.Atoi(value) | ||||||
| 	case "ModelRatio": | 	case "ModelRatio": | ||||||
| 		err = common.UpdateModelRatioByJSONString(value) | 		err = billingratio.UpdateModelRatioByJSONString(value) | ||||||
| 	case "GroupRatio": | 	case "GroupRatio": | ||||||
| 		err = common.UpdateGroupRatioByJSONString(value) | 		err = billingratio.UpdateGroupRatioByJSONString(value) | ||||||
|  | 	case "CompletionRatio": | ||||||
|  | 		err = billingratio.UpdateCompletionRatioByJSONString(value) | ||||||
| 	case "TopUpLink": | 	case "TopUpLink": | ||||||
| 		common.TopUpLink = value | 		config.TopUpLink = value | ||||||
| 	case "ChatLink": | 	case "ChatLink": | ||||||
| 		common.ChatLink = value | 		config.ChatLink = value | ||||||
| 	case "ChannelDisableThreshold": | 	case "ChannelDisableThreshold": | ||||||
| 		common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) | 		config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) | ||||||
| 	case "QuotaPerUnit": | 	case "QuotaPerUnit": | ||||||
| 		common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) | 		config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) | ||||||
| 	case "Theme": | 	case "Theme": | ||||||
| 		common.Theme = value | 		config.Theme = value | ||||||
| 	} | 	} | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,8 +3,15 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||||
|  | 	RedemptionCodeStatusDisabled = 2 // also don't use 0 | ||||||
|  | 	RedemptionCodeStatusUsed     = 3 // also don't use 0 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Redemption struct { | type Redemption struct { | ||||||
| @@ -13,7 +20,7 @@ type Redemption struct { | |||||||
| 	Key          string `json:"key" gorm:"type:char(32);uniqueIndex"` | 	Key          string `json:"key" gorm:"type:char(32);uniqueIndex"` | ||||||
| 	Status       int    `json:"status" gorm:"default:1"` | 	Status       int    `json:"status" gorm:"default:1"` | ||||||
| 	Name         string `json:"name" gorm:"index"` | 	Name         string `json:"name" gorm:"index"` | ||||||
| 	Quota        int    `json:"quota" gorm:"default:100"` | 	Quota        int64  `json:"quota" gorm:"bigint;default:100"` | ||||||
| 	CreatedTime  int64  `json:"created_time" gorm:"bigint"` | 	CreatedTime  int64  `json:"created_time" gorm:"bigint"` | ||||||
| 	RedeemedTime int64  `json:"redeemed_time" gorm:"bigint"` | 	RedeemedTime int64  `json:"redeemed_time" gorm:"bigint"` | ||||||
| 	Count        int    `json:"count" gorm:"-:all"` // only for api request | 	Count        int    `json:"count" gorm:"-:all"` // only for api request | ||||||
| @@ -41,7 +48,7 @@ func GetRedemptionById(id int) (*Redemption, error) { | |||||||
| 	return &redemption, err | 	return &redemption, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func Redeem(key string, userId int) (quota int, err error) { | func Redeem(key string, userId int) (quota int64, err error) { | ||||||
| 	if key == "" { | 	if key == "" { | ||||||
| 		return 0, errors.New("未提供兑换码") | 		return 0, errors.New("未提供兑换码") | ||||||
| 	} | 	} | ||||||
| @@ -60,15 +67,15 @@ func Redeem(key string, userId int) (quota int, err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errors.New("无效的兑换码") | 			return errors.New("无效的兑换码") | ||||||
| 		} | 		} | ||||||
| 		if redemption.Status != common.RedemptionCodeStatusEnabled { | 		if redemption.Status != RedemptionCodeStatusEnabled { | ||||||
| 			return errors.New("该兑换码已被使用") | 			return errors.New("该兑换码已被使用") | ||||||
| 		} | 		} | ||||||
| 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		redemption.RedeemedTime = common.GetTimestamp() | 		redemption.RedeemedTime = helper.GetTimestamp() | ||||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | 		redemption.Status = RedemptionCodeStatusUsed | ||||||
| 		err = tx.Save(redemption).Error | 		err = tx.Save(redemption).Error | ||||||
| 		return err | 		return err | ||||||
| 	}) | 	}) | ||||||
|   | |||||||
							
								
								
									
										104
									
								
								model/token.go
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								model/token.go
									
									
									
									
									
								
							| @@ -3,28 +3,52 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value! | ||||||
|  | 	TokenStatusDisabled  = 2 // also don't use 0 | ||||||
|  | 	TokenStatusExpired   = 3 | ||||||
|  | 	TokenStatusExhausted = 4 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Token struct { | type Token struct { | ||||||
| 	Id             int    `json:"id"` | 	Id             int     `json:"id"` | ||||||
| 	UserId         int    `json:"user_id"` | 	UserId         int     `json:"user_id"` | ||||||
| 	Key            string `json:"key" gorm:"type:char(48);uniqueIndex"` | 	Key            string  `json:"key" gorm:"type:char(48);uniqueIndex"` | ||||||
| 	Status         int    `json:"status" gorm:"default:1"` | 	Status         int     `json:"status" gorm:"default:1"` | ||||||
| 	Name           string `json:"name" gorm:"index" ` | 	Name           string  `json:"name" gorm:"index" ` | ||||||
| 	CreatedTime    int64  `json:"created_time" gorm:"bigint"` | 	CreatedTime    int64   `json:"created_time" gorm:"bigint"` | ||||||
| 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"` | 	AccessedTime   int64   `json:"accessed_time" gorm:"bigint"` | ||||||
| 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | 	ExpiredTime    int64   `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||||
| 	RemainQuota    int    `json:"remain_quota" gorm:"default:0"` | 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"` | ||||||
| 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"` | 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"` | ||||||
| 	UsedQuota      int    `json:"used_quota" gorm:"default:0"` // used quota | 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||||
|  | 	Models         *string `json:"models" gorm:"default:''"`           // allowed models | ||||||
|  | 	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { | func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { | ||||||
| 	var tokens []*Token | 	var tokens []*Token | ||||||
| 	var err error | 	var err error | ||||||
| 	err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error | 	query := DB.Where("user_id = ?", userId) | ||||||
|  |  | ||||||
|  | 	switch order { | ||||||
|  | 	case "remain_quota": | ||||||
|  | 		query = query.Order("unlimited_quota desc, remain_quota desc") | ||||||
|  | 	case "used_quota": | ||||||
|  | 		query = query.Order("used_quota desc") | ||||||
|  | 	default: | ||||||
|  | 		query = query.Order("id desc") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = query.Limit(num).Offset(startIdx).Find(&tokens).Error | ||||||
| 	return tokens, err | 	return tokens, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -39,26 +63,26 @@ func ValidateUserToken(key string) (token *Token, err error) { | |||||||
| 	} | 	} | ||||||
| 	token, err = CacheGetTokenByKey(key) | 	token, err = CacheGetTokenByKey(key) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("CacheGetTokenByKey failed: " + err.Error()) | 		logger.SysError("CacheGetTokenByKey failed: " + err.Error()) | ||||||
| 		if errors.Is(err, gorm.ErrRecordNotFound) { | 		if errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
| 			return nil, errors.New("无效的令牌") | 			return nil, errors.New("无效的令牌") | ||||||
| 		} | 		} | ||||||
| 		return nil, errors.New("令牌验证失败") | 		return nil, errors.New("令牌验证失败") | ||||||
| 	} | 	} | ||||||
| 	if token.Status == common.TokenStatusExhausted { | 	if token.Status == TokenStatusExhausted { | ||||||
| 		return nil, errors.New("该令牌额度已用尽") | 		return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) | ||||||
| 	} else if token.Status == common.TokenStatusExpired { | 	} else if token.Status == TokenStatusExpired { | ||||||
| 		return nil, errors.New("该令牌已过期") | 		return nil, errors.New("该令牌已过期") | ||||||
| 	} | 	} | ||||||
| 	if token.Status != common.TokenStatusEnabled { | 	if token.Status != TokenStatusEnabled { | ||||||
| 		return nil, errors.New("该令牌状态不可用") | 		return nil, errors.New("该令牌状态不可用") | ||||||
| 	} | 	} | ||||||
| 	if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | 	if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { | ||||||
| 		if !common.RedisEnabled { | 		if !common.RedisEnabled { | ||||||
| 			token.Status = common.TokenStatusExpired | 			token.Status = TokenStatusExpired | ||||||
| 			err := token.SelectUpdate() | 			err := token.SelectUpdate() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("failed to update token status" + err.Error()) | 				logger.SysError("failed to update token status" + err.Error()) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		return nil, errors.New("该令牌已过期") | 		return nil, errors.New("该令牌已过期") | ||||||
| @@ -66,10 +90,10 @@ func ValidateUserToken(key string) (token *Token, err error) { | |||||||
| 	if !token.UnlimitedQuota && token.RemainQuota <= 0 { | 	if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||||
| 		if !common.RedisEnabled { | 		if !common.RedisEnabled { | ||||||
| 			// in this case, we can make sure the token is exhausted | 			// in this case, we can make sure the token is exhausted | ||||||
| 			token.Status = common.TokenStatusExhausted | 			token.Status = TokenStatusExhausted | ||||||
| 			err := token.SelectUpdate() | 			err := token.SelectUpdate() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("failed to update token status" + err.Error()) | 				logger.SysError("failed to update token status" + err.Error()) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		return nil, errors.New("该令牌额度已用尽") | 		return nil, errors.New("该令牌额度已用尽") | ||||||
| @@ -106,7 +130,7 @@ func (token *Token) Insert() error { | |||||||
| // Update Make sure your token's fields is completed, because this will update non-zero values | // Update Make sure your token's fields is completed, because this will update non-zero values | ||||||
| func (token *Token) Update() error { | func (token *Token) Update() error { | ||||||
| 	var err error | 	var err error | ||||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error | 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -134,51 +158,51 @@ func DeleteTokenById(id int, userId int) (err error) { | |||||||
| 	return token.Delete() | 	return token.Delete() | ||||||
| } | } | ||||||
|  |  | ||||||
| func IncreaseTokenQuota(id int, quota int) (err error) { | func IncreaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if common.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return increaseTokenQuota(id, quota) | 	return increaseTokenQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func increaseTokenQuota(id int, quota int) (err error) { | func increaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||||
| 			"used_quota":    gorm.Expr("used_quota - ?", quota), | 			"used_quota":    gorm.Expr("used_quota - ?", quota), | ||||||
| 			"accessed_time": common.GetTimestamp(), | 			"accessed_time": helper.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func DecreaseTokenQuota(id int, quota int) (err error) { | func DecreaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if common.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return decreaseTokenQuota(id, quota) | 	return decreaseTokenQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func decreaseTokenQuota(id int, quota int) (err error) { | func decreaseTokenQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
| 			"accessed_time": common.GetTimestamp(), | 			"accessed_time": helper.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| @@ -196,24 +220,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | |||||||
| 	if userQuota < quota { | 	if userQuota < quota { | ||||||
| 		return errors.New("用户额度不足") | 		return errors.New("用户额度不足") | ||||||
| 	} | 	} | ||||||
| 	quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold | 	quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold | ||||||
| 	noMoreQuota := userQuota-quota <= 0 | 	noMoreQuota := userQuota-quota <= 0 | ||||||
| 	if quotaTooLow || noMoreQuota { | 	if quotaTooLow || noMoreQuota { | ||||||
| 		go func() { | 		go func() { | ||||||
| 			email, err := GetUserEmail(token.UserId) | 			email, err := GetUserEmail(token.UserId) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				common.SysError("failed to fetch user email: " + err.Error()) | 				logger.SysError("failed to fetch user email: " + err.Error()) | ||||||
| 			} | 			} | ||||||
| 			prompt := "您的额度即将用尽" | 			prompt := "您的额度即将用尽" | ||||||
| 			if noMoreQuota { | 			if noMoreQuota { | ||||||
| 				prompt = "您的额度已用尽" | 				prompt = "您的额度已用尽" | ||||||
| 			} | 			} | ||||||
| 			if email != "" { | 			if email != "" { | ||||||
| 				topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) | 				topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) | ||||||
| 				err = common.SendEmail(prompt, email, | 				err = message.SendEmail(prompt, email, | ||||||
| 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("failed to send email" + err.Error()) | 					logger.SysError("failed to send email" + err.Error()) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| @@ -228,7 +252,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func PostConsumeTokenQuota(tokenId int, quota int) (err error) { | func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||||
| 	token, err := GetTokenById(tokenId) | 	token, err := GetTokenById(tokenId) | ||||||
| 	if quota > 0 { | 	if quota > 0 { | ||||||
| 		err = DecreaseUserQuota(token.UserId, quota) | 		err = DecreaseUserQuota(token.UserId, quota) | ||||||
|   | |||||||
							
								
								
									
										156
									
								
								model/user.go
									
									
									
									
									
								
							
							
						
						
									
										156
									
								
								model/user.go
									
									
									
									
									
								
							| @@ -3,11 +3,29 @@ package model | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/blacklist" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/helper" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/random" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RoleGuestUser  = 0 | ||||||
|  | 	RoleCommonUser = 1 | ||||||
|  | 	RoleAdminUser  = 10 | ||||||
|  | 	RoleRootUser   = 100 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	UserStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||||
|  | 	UserStatusDisabled = 2 // also don't use 0 | ||||||
|  | 	UserStatusDeleted  = 3 | ||||||
|  | ) | ||||||
|  |  | ||||||
| // User if you add sensitive fields, don't forget to clean them in setupLogin function. | // User if you add sensitive fields, don't forget to clean them in setupLogin function. | ||||||
| // Otherwise, the sensitive information will be saved on local storage in plain text! | // Otherwise, the sensitive information will be saved on local storage in plain text! | ||||||
| type User struct { | type User struct { | ||||||
| @@ -15,16 +33,17 @@ type User struct { | |||||||
| 	Username         string `json:"username" gorm:"unique;index" validate:"max=12"` | 	Username         string `json:"username" gorm:"unique;index" validate:"max=12"` | ||||||
| 	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"` | 	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"` | ||||||
| 	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"` | 	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"` | ||||||
| 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, common | 	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, util | ||||||
| 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | ||||||
| 	Email            string `json:"email" gorm:"index" validate:"max=50"` | 	Email            string `json:"email" gorm:"index" validate:"max=50"` | ||||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||||
|  | 	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"` | ||||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||||
| 	Quota            int    `json:"quota" gorm:"type:int;default:0"` | 	Quota            int64  `json:"quota" gorm:"bigint;default:0"` | ||||||
| 	UsedQuota        int    `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota | 	UsedQuota        int64  `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota | ||||||
| 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`               // request number | 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`             // request number | ||||||
| 	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"` | 	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||||
| 	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` | 	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` | ||||||
| 	InviterId        int    `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` | 	InviterId        int    `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` | ||||||
| @@ -36,8 +55,21 @@ func GetMaxUserId() int { | |||||||
| 	return user.Id | 	return user.Id | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllUsers(startIdx int, num int) (users []*User, err error) { | func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { | ||||||
| 	err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error | 	query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted) | ||||||
|  |  | ||||||
|  | 	switch order { | ||||||
|  | 	case "quota": | ||||||
|  | 		query = query.Order("quota desc") | ||||||
|  | 	case "used_quota": | ||||||
|  | 		query = query.Order("used_quota desc") | ||||||
|  | 	case "request_count": | ||||||
|  | 		query = query.Order("request_count desc") | ||||||
|  | 	default: | ||||||
|  | 		query = query.Order("id desc") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = query.Find(&users).Error | ||||||
| 	return users, err | 	return users, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -89,26 +121,42 @@ func (user *User) Insert(inviterId int) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	user.Quota = common.QuotaForNewUser | 	user.Quota = config.QuotaForNewUser | ||||||
| 	user.AccessToken = common.GetUUID() | 	user.AccessToken = random.GetUUID() | ||||||
| 	user.AffCode = common.GetRandomString(4) | 	user.AffCode = random.GetRandomString(4) | ||||||
| 	result := DB.Create(user) | 	result := DB.Create(user) | ||||||
| 	if result.Error != nil { | 	if result.Error != nil { | ||||||
| 		return result.Error | 		return result.Error | ||||||
| 	} | 	} | ||||||
| 	if common.QuotaForNewUser > 0 { | 	if config.QuotaForNewUser > 0 { | ||||||
| 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) | 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) | ||||||
| 	} | 	} | ||||||
| 	if inviterId != 0 { | 	if inviterId != 0 { | ||||||
| 		if common.QuotaForInvitee > 0 { | 		if config.QuotaForInvitee > 0 { | ||||||
| 			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) | 			_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) | ||||||
| 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) | 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) | ||||||
| 		} | 		} | ||||||
| 		if common.QuotaForInviter > 0 { | 		if config.QuotaForInviter > 0 { | ||||||
| 			_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) | 			_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) | ||||||
| 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) | 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	// create default token | ||||||
|  | 	cleanToken := Token{ | ||||||
|  | 		UserId:         user.Id, | ||||||
|  | 		Name:           "default", | ||||||
|  | 		Key:            random.GenerateKey(), | ||||||
|  | 		CreatedTime:    helper.GetTimestamp(), | ||||||
|  | 		AccessedTime:   helper.GetTimestamp(), | ||||||
|  | 		ExpiredTime:    -1, | ||||||
|  | 		RemainQuota:    -1, | ||||||
|  | 		UnlimitedQuota: true, | ||||||
|  | 	} | ||||||
|  | 	result.Error = cleanToken.Insert() | ||||||
|  | 	if result.Error != nil { | ||||||
|  | 		// do not block | ||||||
|  | 		logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error())) | ||||||
|  | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -120,6 +168,11 @@ func (user *User) Update(updatePassword bool) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	if user.Status == UserStatusDisabled { | ||||||
|  | 		blacklist.BanUser(user.Id) | ||||||
|  | 	} else if user.Status == UserStatusEnabled { | ||||||
|  | 		blacklist.UnbanUser(user.Id) | ||||||
|  | 	} | ||||||
| 	err = DB.Model(user).Updates(user).Error | 	err = DB.Model(user).Updates(user).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -128,7 +181,10 @@ func (user *User) Delete() error { | |||||||
| 	if user.Id == 0 { | 	if user.Id == 0 { | ||||||
| 		return errors.New("id 为空!") | 		return errors.New("id 为空!") | ||||||
| 	} | 	} | ||||||
| 	err := DB.Delete(user).Error | 	blacklist.BanUser(user.Id) | ||||||
|  | 	user.Username = fmt.Sprintf("deleted_%s", random.GetUUID()) | ||||||
|  | 	user.Status = UserStatusDeleted | ||||||
|  | 	err := DB.Model(user).Updates(user).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -141,9 +197,17 @@ func (user *User) ValidateAndFill() (err error) { | |||||||
| 	if user.Username == "" || password == "" { | 	if user.Username == "" || password == "" { | ||||||
| 		return errors.New("用户名或密码为空") | 		return errors.New("用户名或密码为空") | ||||||
| 	} | 	} | ||||||
| 	DB.Where(User{Username: user.Username}).First(user) | 	err = DB.Where("username = ?", user.Username).First(user).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		// we must make sure check username firstly | ||||||
|  | 		// consider this case: a malicious user set his username as other's email | ||||||
|  | 		err := DB.Where("email = ?", user.Username).First(user).Error | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errors.New("用户名或密码错误,或用户已被封禁") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	okay := common.ValidatePasswordAndHash(password, user.Password) | 	okay := common.ValidatePasswordAndHash(password, user.Password) | ||||||
| 	if !okay || user.Status != common.UserStatusEnabled { | 	if !okay || user.Status != UserStatusEnabled { | ||||||
| 		return errors.New("用户名或密码错误,或用户已被封禁") | 		return errors.New("用户名或密码错误,或用户已被封禁") | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| @@ -173,6 +237,14 @@ func (user *User) FillUserByGitHubId() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (user *User) FillUserByLarkId() error { | ||||||
|  | 	if user.LarkId == "" { | ||||||
|  | 		return errors.New("lark id 为空!") | ||||||
|  | 	} | ||||||
|  | 	DB.Where(User{LarkId: user.LarkId}).First(user) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (user *User) FillUserByWeChatId() error { | func (user *User) FillUserByWeChatId() error { | ||||||
| 	if user.WeChatId == "" { | 	if user.WeChatId == "" { | ||||||
| 		return errors.New("WeChat id 为空!") | 		return errors.New("WeChat id 为空!") | ||||||
| @@ -201,6 +273,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { | |||||||
| 	return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 | 	return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func IsLarkIdAlreadyTaken(githubId string) bool { | ||||||
|  | 	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||||
|  | } | ||||||
|  |  | ||||||
| func IsUsernameAlreadyTaken(username string) bool { | func IsUsernameAlreadyTaken(username string) bool { | ||||||
| 	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 | 	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 | ||||||
| } | } | ||||||
| @@ -224,10 +300,10 @@ func IsAdmin(userId int) bool { | |||||||
| 	var user User | 	var user User | ||||||
| 	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error | 	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("no such user " + err.Error()) | 		logger.SysError("no such user " + err.Error()) | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 	return user.Role >= common.RoleAdminUser | 	return user.Role >= RoleAdminUser | ||||||
| } | } | ||||||
|  |  | ||||||
| func IsUserEnabled(userId int) (bool, error) { | func IsUserEnabled(userId int) (bool, error) { | ||||||
| @@ -239,7 +315,7 @@ func IsUserEnabled(userId int) (bool, error) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
| 	return user.Status == common.UserStatusEnabled, nil | 	return user.Status == UserStatusEnabled, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func ValidateAccessToken(token string) (user *User) { | func ValidateAccessToken(token string) (user *User) { | ||||||
| @@ -254,12 +330,12 @@ func ValidateAccessToken(token string) (user *User) { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserQuota(id int) (quota int, err error) { | func GetUserQuota(id int) (quota int64, err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error | ||||||
| 	return quota, err | 	return quota, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserUsedQuota(id int) (quota int, err error) { | func GetUserUsedQuota(id int) (quota int64, err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error | ||||||
| 	return quota, err | 	return quota, err | ||||||
| } | } | ||||||
| @@ -279,45 +355,45 @@ func GetUserGroup(id int) (group string, err error) { | |||||||
| 	return group, err | 	return group, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func IncreaseUserQuota(id int, quota int) (err error) { | func IncreaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if common.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return increaseUserQuota(id, quota) | 	return increaseUserQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func increaseUserQuota(id int, quota int) (err error) { | func increaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func DecreaseUserQuota(id int, quota int) (err error) { | func DecreaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
| 	if common.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return decreaseUserQuota(id, quota) | 	return decreaseUserQuota(id, quota) | ||||||
| } | } | ||||||
|  |  | ||||||
| func decreaseUserQuota(id int, quota int) (err error) { | func decreaseUserQuota(id int, quota int64) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetRootUserEmail() (email string) { | func GetRootUserEmail() (email string) { | ||||||
| 	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) | 	DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email) | ||||||
| 	return email | 	return email | ||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { | ||||||
| 	if common.BatchUpdateEnabled { | 	if config.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||||
| 		return | 		return | ||||||
| @@ -325,7 +401,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | |||||||
| 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
| @@ -333,25 +409,25 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | |||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update user used quota and request count: " + err.Error()) | 		logger.SysError("failed to update user used quota and request count: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateUserUsedQuota(id int, quota int) { | func updateUserUsedQuota(id int, quota int64) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"used_quota": gorm.Expr("used_quota + ?", quota), | 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update user used quota: " + err.Error()) | 		logger.SysError("failed to update user used quota: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func updateUserRequestCount(id int, count int) { | func updateUserRequestCount(id int, count int) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update user request count: " + err.Error()) | 		logger.SysError("failed to update user request count: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,7 +1,8 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"one-api/common" | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| @@ -15,12 +16,12 @@ const ( | |||||||
| 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var batchUpdateStores []map[int]int | var batchUpdateStores []map[int]int64 | ||||||
| var batchUpdateLocks []sync.Mutex | var batchUpdateLocks []sync.Mutex | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
| 		batchUpdateStores = append(batchUpdateStores, make(map[int]int)) | 		batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) | ||||||
| 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -28,13 +29,13 @@ func init() { | |||||||
| func InitBatchUpdater() { | func InitBatchUpdater() { | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) | 			time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second) | ||||||
| 			batchUpdate() | 			batchUpdate() | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| } | } | ||||||
|  |  | ||||||
| func addNewRecord(type_ int, id int, value int) { | func addNewRecord(type_ int, id int, value int64) { | ||||||
| 	batchUpdateLocks[type_].Lock() | 	batchUpdateLocks[type_].Lock() | ||||||
| 	defer batchUpdateLocks[type_].Unlock() | 	defer batchUpdateLocks[type_].Unlock() | ||||||
| 	if _, ok := batchUpdateStores[type_][id]; !ok { | 	if _, ok := batchUpdateStores[type_][id]; !ok { | ||||||
| @@ -45,11 +46,11 @@ func addNewRecord(type_ int, id int, value int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func batchUpdate() { | func batchUpdate() { | ||||||
| 	common.SysLog("batch update started") | 	logger.SysLog("batch update started") | ||||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
| 		batchUpdateLocks[i].Lock() | 		batchUpdateLocks[i].Lock() | ||||||
| 		store := batchUpdateStores[i] | 		store := batchUpdateStores[i] | ||||||
| 		batchUpdateStores[i] = make(map[int]int) | 		batchUpdateStores[i] = make(map[int]int64) | ||||||
| 		batchUpdateLocks[i].Unlock() | 		batchUpdateLocks[i].Unlock() | ||||||
| 		// TODO: maybe we can combine updates with same key? | 		// TODO: maybe we can combine updates with same key? | ||||||
| 		for key, value := range store { | 		for key, value := range store { | ||||||
| @@ -57,21 +58,21 @@ func batchUpdate() { | |||||||
| 			case BatchUpdateTypeUserQuota: | 			case BatchUpdateTypeUserQuota: | ||||||
| 				err := increaseUserQuota(key, value) | 				err := increaseUserQuota(key, value) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("failed to batch update user quota: " + err.Error()) | 					logger.SysError("failed to batch update user quota: " + err.Error()) | ||||||
| 				} | 				} | ||||||
| 			case BatchUpdateTypeTokenQuota: | 			case BatchUpdateTypeTokenQuota: | ||||||
| 				err := increaseTokenQuota(key, value) | 				err := increaseTokenQuota(key, value) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("failed to batch update token quota: " + err.Error()) | 					logger.SysError("failed to batch update token quota: " + err.Error()) | ||||||
| 				} | 				} | ||||||
| 			case BatchUpdateTypeUsedQuota: | 			case BatchUpdateTypeUsedQuota: | ||||||
| 				updateUserUsedQuota(key, value) | 				updateUserUsedQuota(key, value) | ||||||
| 			case BatchUpdateTypeRequestCount: | 			case BatchUpdateTypeRequestCount: | ||||||
| 				updateUserRequestCount(key, value) | 				updateUserRequestCount(key, int(value)) | ||||||
| 			case BatchUpdateTypeChannelUsedQuota: | 			case BatchUpdateTypeChannelUsedQuota: | ||||||
| 				updateChannelUsedQuota(key, value) | 				updateChannelUsedQuota(key, value) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	common.SysLog("batch update finished") | 	logger.SysLog("batch update finished") | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										54
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | |||||||
|  | package monitor | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/logger" | ||||||
|  | 	"github.com/songquanpeng/one-api/common/message" | ||||||
|  | 	"github.com/songquanpeng/one-api/model" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func notifyRootUser(subject string, content string) { | ||||||
|  | 	if config.MessagePusherAddress != "" { | ||||||
|  | 		err := message.SendMessage(subject, content, content) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) | ||||||
|  | 		} else { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if config.RootUserEmail == "" { | ||||||
|  | 		config.RootUserEmail = model.GetRootUserEmail() | ||||||
|  | 	} | ||||||
|  | 	err := message.SendEmail(subject, config.RootUserEmail, content) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // DisableChannel disable & notify | ||||||
|  | func DisableChannel(channelId int, channelName string, reason string) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) | ||||||
|  | 	subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId) | ||||||
|  | 	content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func MetricDisableChannel(channelId int, successRate float64) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) | ||||||
|  | 	subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) | ||||||
|  | 	content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", | ||||||
|  | 		channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // EnableChannel enable & notify | ||||||
|  | func EnableChannel(channelId int, channelName string) { | ||||||
|  | 	model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled) | ||||||
|  | 	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) | ||||||
|  | 	subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) | ||||||
|  | 	content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) | ||||||
|  | 	notifyRootUser(subject, content) | ||||||
|  | } | ||||||
							
								
								
									
										62
									
								
								monitor/manage.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								monitor/manage.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | |||||||
|  | package monitor | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | 	"github.com/songquanpeng/one-api/relay/model" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||||
|  | 	if !config.AutomaticDisableChannelEnabled { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if err == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if statusCode == http.StatusUnauthorized { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	switch err.Type { | ||||||
|  | 	case "insufficient_quota": | ||||||
|  | 		return true | ||||||
|  | 	// https://docs.anthropic.com/claude/reference/errors | ||||||
|  | 	case "authentication_error": | ||||||
|  | 		return true | ||||||
|  | 	case "permission_error": | ||||||
|  | 		return true | ||||||
|  | 	case "forbidden": | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic | ||||||
|  | 		return true | ||||||
|  | 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	//if strings.Contains(err.Message, "quota") { | ||||||
|  | 	//	return true | ||||||
|  | 	//} | ||||||
|  | 	if strings.Contains(err.Message, "credit") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if strings.Contains(err.Message, "balance") { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ShouldEnableChannel(err error, openAIErr *model.Error) bool { | ||||||
|  | 	if !config.AutomaticEnableChannelEnabled { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if openAIErr != nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
							
								
								
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | |||||||
|  | package monitor | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/songquanpeng/one-api/common/config" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var store = make(map[int][]bool) | ||||||
|  | var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) | ||||||
|  | var metricFailChan = make(chan int, config.MetricFailChanSize) | ||||||
|  |  | ||||||
|  | func consumeSuccess(channelId int) { | ||||||
|  | 	if len(store[channelId]) > config.MetricQueueSize { | ||||||
|  | 		store[channelId] = store[channelId][1:] | ||||||
|  | 	} | ||||||
|  | 	store[channelId] = append(store[channelId], true) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func consumeFail(channelId int) (bool, float64) { | ||||||
|  | 	if len(store[channelId]) > config.MetricQueueSize { | ||||||
|  | 		store[channelId] = store[channelId][1:] | ||||||
|  | 	} | ||||||
|  | 	store[channelId] = append(store[channelId], false) | ||||||
|  | 	successCount := 0 | ||||||
|  | 	for _, success := range store[channelId] { | ||||||
|  | 		if success { | ||||||
|  | 			successCount++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	successRate := float64(successCount) / float64(len(store[channelId])) | ||||||
|  | 	if len(store[channelId]) < config.MetricQueueSize { | ||||||
|  | 		return false, successRate | ||||||
|  | 	} | ||||||
|  | 	if successRate < config.MetricSuccessRateThreshold { | ||||||
|  | 		store[channelId] = make([]bool, 0) | ||||||
|  | 		return true, successRate | ||||||
|  | 	} | ||||||
|  | 	return false, successRate | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func metricSuccessConsumer() { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case channelId := <-metricSuccessChan: | ||||||
|  | 			consumeSuccess(channelId) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func metricFailConsumer() { | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case channelId := <-metricFailChan: | ||||||
|  | 			disable, successRate := consumeFail(channelId) | ||||||
|  | 			if disable { | ||||||
|  | 				go MetricDisableChannel(channelId, successRate) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	if config.EnableMetric { | ||||||
|  | 		go metricSuccessConsumer() | ||||||
|  | 		go metricFailConsumer() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Emit(channelId int, success bool) { | ||||||
|  | 	if !config.EnableMetric { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	go func() { | ||||||
|  | 		if success { | ||||||
|  | 			metricSuccessChan <- channelId | ||||||
|  | 		} else { | ||||||
|  | 			metricFailChan <- channelId | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
| @@ -1,9 +1,10 @@ | |||||||
| [//]: # (请按照以下格式关联 issue) | [//]: # (请按照以下格式关联 issue) | ||||||
| [//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢) | [//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢) | ||||||
| [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) | [//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解) | ||||||
| [//]: # (开发者交流群:910657413) | [//]: # (开发者交流群:910657413) | ||||||
| [//]: # (请在提交 PR 之前删除上面的注释) | [//]: # (请在提交 PR 之前删除上面的注释) | ||||||
|  |  | ||||||
| close #issue_number | close #issue_number | ||||||
|  |  | ||||||
| 我已确认该 PR 已自测通过,相关截图如下: | 我已确认该 PR 已自测通过,相关截图如下: | ||||||
|  | (此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图) | ||||||
|   | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user