mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 19:03:43 +08:00 
			
		
		
		
	Compare commits
	
		
			151 Commits
		
	
	
		
			v0.6.8-alp
			...
			v0.6.11-al
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 93ce6c4cd7 | ||
|  | 4fe5ab8d09 | ||
|  | afbbfbbf83 | ||
|  | 75d9d9d560 | ||
|  | d1af30ee5a | ||
|  | a3924a2353 | ||
|  | 9af5a1d11d | ||
|  | 57f9f7dfbb | ||
|  | d9f2df2baf | ||
|  | bdf312e5dc | ||
|  | 1521df6551 | ||
|  | c67b167f4f | ||
|  | c351e196e6 | ||
|  | a316ed7abc | ||
|  | 0895d8660e | ||
|  | be1ed114f4 | ||
|  | eb6da573a3 | ||
|  | 0a6273fc08 | ||
|  | 5997fce454 | ||
|  | 0df6d7a131 | ||
|  | 93fdb60de5 | ||
|  | 4db834da95 | ||
|  | 6818ed5ca8 | ||
|  | 7be3b5547d | ||
|  | 2d7ea61d67 | ||
|  | 83b34be067 | ||
|  | d5d879afdc | ||
|  | 0f205a3aa3 | ||
|  | 76c3f87351 | ||
|  | 6d9a92f8f7 | ||
|  | 835f0e0d67 | ||
|  | a6981f0d51 | ||
|  | 678d613179 | ||
|  | be089a072b | ||
|  | 45d10aa3df | ||
|  | 9cdd48ac22 | ||
|  | 310e7120e5 | ||
|  | 3d29713268 | ||
|  | f2c7c424e9 | ||
|  | 38a42bb265 | ||
|  | fa2e8f44b1 | ||
|  | 9f74101543 | ||
|  | 28a271a896 | ||
|  | e8ea87fff3 | ||
|  | abe2d2dba8 | ||
|  | 4bcaa064d6 | ||
|  | 52d81e0e24 | ||
|  | dc8c3bc69e | ||
|  | b4e69df802 | ||
|  | d9f74bdff3 | ||
|  | fa2a772731 | ||
|  | 4f68f3e1b3 | ||
|  | 0bab887b2d | ||
|  | 0230d36643 | ||
|  | bad57d049a | ||
|  | dc470ce82e | ||
|  | ea0721d525 | ||
|  | d0402f9086 | ||
|  | 1fead8e7f7 | ||
|  | 09911a301d | ||
|  | f95e6b78b8 | ||
|  | 605bb06667 | ||
|  | d88e07fd9a | ||
|  | 3915ce9814 | ||
|  | 999defc88b | ||
|  | b51c47bc77 | ||
|  | 4f25cde132 | ||
|  | d89e9d7e44 | ||
|  | a858292b54 | ||
|  | ff589b5e4a | ||
|  | 95e8c16338 | ||
|  | 381172cb36 | ||
|  | 59eae186a3 | ||
|  | ce52f355bb | ||
|  | cb9d0a74c9 | ||
|  | 49ffb1c60d | ||
|  | 2f16649896 | ||
|  | af3aa57bd6 | ||
|  | e9f117ff72 | ||
|  | 6bb5247bd6 | ||
|  | 305ce14fe3 | ||
|  | 36c8f4f15c | ||
|  | 45b51ea0ee | ||
|  | 7c8628bd95 | ||
|  | 6ab87f8a08 | ||
|  | 833fa7ad6f | ||
|  | 6eb0770a89 | ||
|  | 92cd46d64f | ||
|  | 2b2dc2c733 | ||
|  | a3d7df7f89 | ||
|  | c368232f50 | ||
|  | cbfc983dc3 | ||
|  | 8ec092ba44 | ||
|  | b0b88a79ff | ||
|  | 7e51b04221 | ||
|  | f75a17f8eb | ||
|  | 6f13a3bb3c | ||
|  | f092eed1db | ||
|  | 629378691b | ||
|  | 3716e1b0e6 | ||
|  | a4d6e7a886 | ||
|  | cb772e5d06 | ||
|  | e32cb0b844 | ||
|  | fdd7bf41c0 | ||
|  | 29389ed44f | ||
|  | 88acc5a614 | ||
|  | a21681096a | ||
|  | 32f90a79a8 | ||
|  | 99c8c77504 | ||
|  | 649ecbf29c | ||
|  | 3a27c90910 | ||
|  | cba82404ae | ||
|  | c9ac670ba1 | ||
|  | 15f815c23c | ||
|  | 89b63ca96f | ||
|  | 8cc54489b9 | ||
|  | 58bf60805e | ||
|  | 6714cf96d6 | ||
|  | f9774698e9 | ||
|  | 2af6f6a166 | ||
|  | 04bb3ef392 | ||
|  | b4bfa418a8 | ||
|  | e7e99e558a | ||
|  | 402fcf7f79 | ||
|  | 36039e329e | ||
|  | c936198ac8 | ||
|  | 296ab013b8 | ||
|  | 5f03c856b4 | ||
|  | 39383e5532 | ||
|  | 2a892c1937 | ||
|  | adba54acd3 | ||
|  | 6209ff9ea9 | ||
|  | 1c44d7e1cd | ||
|  | a3eefb7af0 | ||
|  | b65bee46fb | ||
|  | 422a4e8ee5 | ||
|  | cf9b5f0b92 | ||
|  | 65acb94f45 | ||
|  | 6ad169975f | ||
|  | f636c50c84 | ||
|  | 720fe2dfeb | ||
|  | e090e76c86 | ||
|  | 6a941748f8 | ||
|  | 46a0773580 | ||
|  | ffdb0b0c81 | ||
|  | efd30a40b3 | ||
|  | d7a78f3397 | ||
|  | 273be55797 | ||
|  | ec6ad24810 | ||
|  | c4fe57c165 | ||
|  | 274fcf3d76 | 
							
								
								
									
										10
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,19 +1,17 @@ | ||||
| name: CI | ||||
|  | ||||
| # This setup assumes that you run the unit tests with code coverage in the same | ||||
| # workflow that will also print the coverage report as comment to the pull request.  | ||||
| # workflow that will also print the coverage report as comment to the pull request. | ||||
| # Therefore, you need to trigger this workflow when a pull request is (re)opened or | ||||
| # when new code is pushed to the branch of the pull request. In addition, you also | ||||
| # need to trigger this workflow when new code is pushed to the main branch because  | ||||
| # need to trigger this workflow when new code is pushed to the main branch because | ||||
| # we need to upload the code coverage results as artifact for the main branch as | ||||
| # well since it will be the baseline code coverage. | ||||
| #  | ||||
| # | ||||
| # We do not want to trigger the workflow for pushes to *any* branch because this | ||||
| # would trigger our jobs twice on pull requests (once from "push" event and once | ||||
| # from "pull_request->synchronize") | ||||
| on: | ||||
|   pull_request: | ||||
|     types: [opened, reopened, synchronize] | ||||
|   push: | ||||
|     branches: | ||||
|       - 'main' | ||||
| @@ -31,7 +29,7 @@ jobs: | ||||
|         with: | ||||
|           go-version: ^1.22 | ||||
|  | ||||
|       # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a  | ||||
|       # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a | ||||
|       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") | ||||
|       # in the next step as well as the next job. | ||||
|       - name: Test | ||||
|   | ||||
							
								
								
									
										56
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										56
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,56 +0,0 @@ | ||||
| name: Publish Docker image (amd64, English) | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - 'v*.*.*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   push_to_registries: | ||||
|     name: Push Docker image to multiple registries | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       packages: write | ||||
|       contents: read | ||||
|     steps: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi       | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|  | ||||
|       - name: Translate | ||||
|         run: | | ||||
|           python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json | ||||
|       - name: Log in to Docker Hub | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} | ||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} | ||||
|  | ||||
|       - name: Extract metadata (tags, labels) for Docker | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api-en | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|         with: | ||||
|           context: . | ||||
|           push: true | ||||
|           tags: ${{ steps.meta.outputs.tags }} | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
							
								
								
									
										61
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										61
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -1,61 +0,0 @@ | ||||
| name: Publish Docker image (amd64) | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - 'v*.*.*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
|         description: 'reason' | ||||
|         required: false | ||||
| jobs: | ||||
|   push_to_registries: | ||||
|     name: Push Docker image to multiple registries | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       packages: write | ||||
|       contents: read | ||||
|     steps: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi         | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|  | ||||
|       - name: Log in to Docker Hub | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           username: ${{ secrets.DOCKERHUB_USERNAME }} | ||||
|           password: ${{ secrets.DOCKERHUB_TOKEN }} | ||||
|  | ||||
|       - name: Log in to the Container registry | ||||
|         uses: docker/login-action@v2 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.actor }} | ||||
|           password: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
|       - name: Extract metadata (tags, labels) for Docker | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api | ||||
|             ghcr.io/${{ github.repository }} | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|         with: | ||||
|           context: . | ||||
|           push: true | ||||
|           tags: ${{ steps.meta.outputs.tags }} | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
| @@ -1,10 +1,9 @@ | ||||
| name: Publish Docker image (arm64) | ||||
| name: Publish Docker image | ||||
| 
 | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - 'v*.*.*' | ||||
|       - '!*-alpha*' | ||||
|   workflow_dispatch: | ||||
|     inputs: | ||||
|       name: | ||||
| @@ -63,7 +62,8 @@ jobs: | ||||
|         uses: docker/build-push-action@v3 | ||||
|         with: | ||||
|           context: . | ||||
|           platforms: linux/amd64,linux/arm64 | ||||
|           #          platforms: linux/amd64,linux/arm64 | ||||
|           platforms: linux/amd64 # TODO disable arm64 for now, because it cause error | ||||
|           push: true | ||||
|           tags: ${{ steps.meta.outputs.tags }} | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -9,4 +9,5 @@ logs | ||||
| data | ||||
| /web/node_modules | ||||
| cmd.md | ||||
| .env | ||||
| .env | ||||
| /one-api | ||||
|   | ||||
							
								
								
									
										45
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								Dockerfile
									
									
									
									
									
								
							| @@ -1,42 +1,51 @@ | ||||
| FROM node:16 as builder | ||||
| FROM --platform=$BUILDPLATFORM node:16 AS builder | ||||
|  | ||||
| WORKDIR /web | ||||
| COPY ./VERSION . | ||||
| COPY ./web . | ||||
|  | ||||
| WORKDIR /web/default | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
| RUN npm install --prefix /web/default & \ | ||||
|     npm install --prefix /web/berry & \ | ||||
|     npm install --prefix /web/air & \ | ||||
|     wait | ||||
|  | ||||
| WORKDIR /web/berry | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| WORKDIR /web/air | ||||
| RUN npm install | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
| RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/default/VERSION) npm run build --prefix /web/default & \ | ||||
|     DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/berry/VERSION) npm run build --prefix /web/berry & \ | ||||
|     DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/air/VERSION) npm run build --prefix /web/air & \ | ||||
|     wait | ||||
|  | ||||
| FROM golang AS builder2 | ||||
|  | ||||
| RUN apt-get update && apt-get install -y --no-install-recommends \ | ||||
|     build-essential \ | ||||
|     sqlite3 libsqlite3-dev \ | ||||
|     && rm -rf /var/lib/apt/lists/* | ||||
|  | ||||
| ENV GO111MODULE=on \ | ||||
|     CGO_ENABLED=1 \ | ||||
|     GOOS=linux | ||||
|     GOOS=linux \ | ||||
|     CGO_CFLAGS="-I/usr/include" \ | ||||
|     CGO_LDFLAGS="-L/usr/lib" | ||||
|  | ||||
| WORKDIR /build | ||||
|  | ||||
| ADD go.mod go.sum ./ | ||||
| RUN go mod download | ||||
|  | ||||
| COPY . . | ||||
| COPY --from=builder /web/build ./web/build | ||||
| RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||
|  | ||||
| FROM alpine | ||||
| RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)'" -o one-api | ||||
|  | ||||
| RUN apk update \ | ||||
|     && apk upgrade \ | ||||
|     && apk add --no-cache ca-certificates tzdata \ | ||||
|     && update-ca-certificates 2>/dev/null || true | ||||
| # Final runtime image | ||||
| FROM ubuntu:22.04 | ||||
|  | ||||
| RUN apt-get update && apt-get install -y --no-install-recommends \ | ||||
|     ca-certificates tzdata bash \ | ||||
|     && rm -rf /var/lib/apt/lists/* | ||||
|  | ||||
| COPY --from=builder2 /build/one-api / | ||||
|  | ||||
| EXPOSE 3000 | ||||
| WORKDIR /data | ||||
| ENTRYPOINT ["/one-api"] | ||||
							
								
								
									
										35
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the | ||||
|     + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs` | ||||
| 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. | ||||
|     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
| 6. 'MEMORY_CACHE_ENABLED': Enabling memory caching can cause a certain delay in updating user quotas, with optional values of 'true' and 'false'. If not set, it defaults to 'false'. | ||||
| 7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. | ||||
|     + Example: `SYNC_FREQUENCY=60` | ||||
| 7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
| 8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. | ||||
|     + Example: `NODE_TYPE=slave` | ||||
| 8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
| 9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. | ||||
|     + Example: `CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
| 10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. | ||||
|     + Example: `CHANNEL_TEST_FREQUENCY=1440` | ||||
| 10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
| 11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. | ||||
|     + Example: `POLLING_INTERVAL=5` | ||||
| 12. `BATCH_UPDATE_ENABLED`: Enabling batch database update aggregation can cause a certain delay in updating user quotas. The optional values are 'true' and 'false', but if not set, it defaults to 'false'. | ||||
|     +Example: ` BATCH_UPDATE_ENABLED=true` | ||||
|     +If you encounter an issue with too many database connections, you can try enabling this option. | ||||
| 13. `BATCH_UPDATE_INTERVAL=5`: The time interval for batch updating aggregates, measured in seconds, defaults to '5'. | ||||
|     +Example: ` BATCH_UPDATE_INTERVAL=5` | ||||
| 14. Request frequency limit: | ||||
|     + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests within three minutes per IP, default to 180. | ||||
|     + `GLOBAL_WEL_RATE_LIMIT`: Global web speed limit, the maximum number of requests within three minutes per IP, default to 60. | ||||
| 15. Encoder cache settings: | ||||
|     +`TIKTOKEN_CACHE_DIR`: By default, when the program starts, it will download the encoding of some common word elements online, such as' gpt-3.5 turbo '. In some unstable network environments or offline situations, it may cause startup problems. This directory can be configured to cache data and can be migrated to an offline environment. | ||||
|     +`DATA_GYM_CACHE_DIR`: Currently, this configuration has the same function as' TIKTOKEN-CACHE-DIR ', but its priority is not as high as it. | ||||
| 16. `RELAY_TIMEOUT`: Relay timeout setting, measured in seconds, with no default timeout time set. | ||||
| 17. `RELAY_PROXY`: After setting up, use this proxy to request APIs. | ||||
| 18. `USER_CONTENT_REQUEST_TIMEOUT`: The timeout period for users to upload and download content, measured in seconds. | ||||
| 19. `USER_CONTENT_REQUEST_PROXY`: After setting up, use this agent to request content uploaded by users, such as images. | ||||
| 20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout setting, measured in milliseconds, default to '3000'. | ||||
| 21. `GEMINI_SAFETY_SETTING`: Gemini's security settings are set to 'BLOCK-NONE' by default. | ||||
| 22. `GEMINI_VERSION`: The Gemini version used by the One API, which defaults to 'v1'. | ||||
| 23. `THE`: The system's theme setting, default to 'default', specific optional values refer to [here] (./web/README. md). | ||||
| 24. `ENABLE_METRIC`: Whether to disable channels based on request success rate, default not enabled, optional values are 'true' and 'false'. | ||||
| 25. `METRIC_QUEUE_SIZE`: Request success rate statistics queue size, default to '10'. | ||||
| 26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold, default to '0.8'. | ||||
| 27. `INITIAL_ROOT_TOKEN`: If this value is set, a root user token with the value of the environment variable will be automatically created when the system starts for the first time. | ||||
| 28. `INITIAL_ROOT_ACCESS_TOKEN`: If this value is set, a system management token will be automatically created for the root user with a value of the environment variable when the system starts for the first time. | ||||
|  | ||||
| ### Command Line Parameters | ||||
| 1. `--port <port_number>`: Specifies the port number on which the server listens. Defaults to `3000`. | ||||
|   | ||||
							
								
								
									
										63
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										63
									
								
								README.md
									
									
									
									
									
								
							| @@ -88,6 +88,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) | ||||
|    + [x] [DeepL](https://www.deepl.com/) | ||||
|    + [x] [together.ai](https://www.together.ai/) | ||||
|    + [x] [novita.ai](https://www.novita.ai/) | ||||
|    + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud) | ||||
|    + [x] [xAI](https://x.ai/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| @@ -112,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + 支持使用飞书进行授权登录。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。 | ||||
|     + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||
| 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||
| @@ -172,6 +175,10 @@ sudo service nginx restart | ||||
|  | ||||
| 初始账号用户名为 `root`,密码为 `123456`。 | ||||
|  | ||||
| ### 通过宝塔面板进行一键部署 | ||||
| 1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装; | ||||
| 2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装; | ||||
| 3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装; | ||||
|  | ||||
| ### 基于 Docker Compose 进行部署 | ||||
|  | ||||
| @@ -215,7 +222,7 @@ docker-compose ps | ||||
| 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 | ||||
| 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 | ||||
| 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 | ||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 | ||||
| 6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟(Redis 集群或者哨兵模式的支持请参考环境变量说明)。 | ||||
| 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 | ||||
|  | ||||
| 环境变量的具体使用方法详见[此处](#环境变量)。 | ||||
| @@ -250,9 +257,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
| #### QChatGPT - QQ机器人 | ||||
| 项目主页:https://github.com/RockChinQ/QChatGPT | ||||
|  | ||||
| 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||
| 根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。 | ||||
|  | ||||
| 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | ||||
| 运行期间可以通过`!model`命令查看、切换可用模型。 | ||||
|  | ||||
| ### 部署到第三方平台 | ||||
| <details> | ||||
| @@ -344,6 +351,11 @@ graph LR | ||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||
|    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||
|    + 如果需要使用哨兵或者集群模式: | ||||
|      + 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。 | ||||
|      + 除此之外还需要设置以下环境变量: | ||||
|        + `REDIS_PASSWORD`:Redis 集群或者哨兵模式下的密码设置。 | ||||
|        + `REDIS_MASTER_NAME`:Redis 哨兵模式下主节点的名称。 | ||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||
|    + 例子:`SESSION_SECRET=random_string` | ||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||
| @@ -369,33 +381,36 @@ graph LR | ||||
|    + 例子:`NODE_TYPE=slave` | ||||
| 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||
| 11. 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
| 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。  | ||||
|    +例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|     + 例子:`POLLING_INTERVAL=5` | ||||
| 13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
| 12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||
| 14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
| 13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||
| 15. 请求频率限制: | ||||
| 14. 请求频率限制: | ||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||
| 16. 编码器缓存设置: | ||||
| 15. 编码器缓存设置: | ||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||
| 17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 18. `RELAY_PROXY`:设置后使用该代理来请求 API。 | ||||
| 19. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 | ||||
| 20. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 | ||||
| 21. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 22. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 23. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 | ||||
| 24. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
| 25. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 26. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||
| 27. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
| 28. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||
| 16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
| 17. `RELAY_PROXY`:设置后使用该代理来请求 API。 | ||||
| 18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 | ||||
| 19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 | ||||
| 20. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 21. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 22. `GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 | ||||
| 23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
| 24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||
| 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
| 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 | ||||
| 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 | ||||
| 29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 30. `TEST_PROMPT`:测试模型时的用户 prompt,默认为 `Print your model name exactly and do not output without any other text.`。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
|   | ||||
| @@ -1,13 +1,14 @@ | ||||
| package config | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| ) | ||||
|  | ||||
| @@ -35,6 +36,7 @@ var PasswordLoginEnabled = true | ||||
| var PasswordRegisterEnabled = true | ||||
| var EmailVerificationEnabled = false | ||||
| var GitHubOAuthEnabled = false | ||||
| var OidcEnabled = false | ||||
| var WeChatAuthEnabled = false | ||||
| var TurnstileCheckEnabled = false | ||||
| var RegisterEnabled = true | ||||
| @@ -70,6 +72,13 @@ var GitHubClientSecret = "" | ||||
| var LarkClientId = "" | ||||
| var LarkClientSecret = "" | ||||
|  | ||||
| var OidcClientId = "" | ||||
| var OidcClientSecret = "" | ||||
| var OidcWellKnown = "" | ||||
| var OidcAuthorizationEndpoint = "" | ||||
| var OidcTokenEndpoint = "" | ||||
| var OidcUserinfoEndpoint = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
| @@ -143,8 +152,15 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) | ||||
|  | ||||
| var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") | ||||
|  | ||||
| var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN") | ||||
|  | ||||
| var GeminiVersion = env.String("GEMINI_VERSION", "v1") | ||||
|  | ||||
| var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) | ||||
|  | ||||
| var RelayProxy = env.String("RELAY_PROXY", "") | ||||
| var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") | ||||
| var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) | ||||
|  | ||||
| var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) | ||||
| var TestPrompt = env.String("TEST_PROMPT", "Print your model name exactly and do not output without any other text.") | ||||
|   | ||||
| @@ -20,4 +20,5 @@ const ( | ||||
| 	BaseURL           = "base_url" | ||||
| 	AvailableModels   = "available_models" | ||||
| 	KeyRequestBody    = "key_request_body" | ||||
| 	SystemPrompt      = "system_prompt" | ||||
| ) | ||||
|   | ||||
| @@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	contentType := c.Request.Header.Get("Content-Type") | ||||
| 	if strings.HasPrefix(contentType, "application/json") { | ||||
| 		err = json.Unmarshal(requestBody, &v) | ||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	} else { | ||||
| 		// skip for now | ||||
| 		// TODO: someday non json request have variant model, we will need to implementation this | ||||
| 		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 		err = c.ShouldBind(&v) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// Reset request body | ||||
| 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,9 +1,8 @@ | ||||
| package helper | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"html/template" | ||||
| 	"log" | ||||
| 	"net" | ||||
| @@ -11,6 +10,10 @@ import ( | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| ) | ||||
|  | ||||
| func OpenBrowser(url string) { | ||||
| @@ -106,6 +109,18 @@ func GenRequestID() string { | ||||
| 	return GetTimeString() + random.GetRandomNumberString(8) | ||||
| } | ||||
|  | ||||
| func SetRequestID(ctx context.Context, id string) context.Context { | ||||
| 	return context.WithValue(ctx, RequestIdKey, id) | ||||
| } | ||||
|  | ||||
| func GetRequestID(ctx context.Context) string { | ||||
| 	rawRequestId := ctx.Value(RequestIdKey) | ||||
| 	if rawRequestId == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return rawRequestId.(string) | ||||
| } | ||||
|  | ||||
| func GetResponseID(c *gin.Context) string { | ||||
| 	logID := c.GetString(RequestIdKey) | ||||
| 	return fmt.Sprintf("chatcmpl-%s", logID) | ||||
| @@ -137,3 +152,23 @@ func String2Int(str string) int { | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func Float64PtrMax(p *float64, maxValue float64) *float64 { | ||||
| 	if p == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if *p > maxValue { | ||||
| 		return &maxValue | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
|  | ||||
| func Float64PtrMin(p *float64, minValue float64) *float64 { | ||||
| 	if p == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	if *p < minValue { | ||||
| 		return &minValue | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
|   | ||||
| @@ -13,3 +13,8 @@ func GetTimeString() string { | ||||
| 	now := time.Now() | ||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||
| } | ||||
|  | ||||
| // CalcElapsedTime return the elapsed time in milliseconds (ms) | ||||
| func CalcElapsedTime(start time.Time) int64 { | ||||
| 	return time.Now().Sub(start).Milliseconds() | ||||
| } | ||||
|   | ||||
| @@ -7,19 +7,25 @@ import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| ) | ||||
|  | ||||
| type loggerLevel string | ||||
|  | ||||
| const ( | ||||
| 	loggerDEBUG = "DEBUG" | ||||
| 	loggerINFO  = "INFO" | ||||
| 	loggerWarn  = "WARN" | ||||
| 	loggerError = "ERR" | ||||
| 	loggerDEBUG loggerLevel = "DEBUG" | ||||
| 	loggerINFO  loggerLevel = "INFO" | ||||
| 	loggerWarn  loggerLevel = "WARN" | ||||
| 	loggerError loggerLevel = "ERROR" | ||||
| 	loggerFatal loggerLevel = "FATAL" | ||||
| ) | ||||
|  | ||||
| var setupLogOnce sync.Once | ||||
| @@ -27,7 +33,12 @@ var setupLogOnce sync.Once | ||||
| func SetupLogger() { | ||||
| 	setupLogOnce.Do(func() { | ||||
| 		if LogDir != "" { | ||||
| 			logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 			var logPath string | ||||
| 			if config.OnlyOneLogFile { | ||||
| 				logPath = filepath.Join(LogDir, "oneapi.log") | ||||
| 			} else { | ||||
| 				logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 			} | ||||
| 			fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||
| 			if err != nil { | ||||
| 				log.Fatal("failed to open log file") | ||||
| @@ -39,27 +50,26 @@ func SetupLogger() { | ||||
| } | ||||
|  | ||||
| func SysLog(s string) { | ||||
| 	t := time.Now() | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||
| 	logHelper(nil, loggerINFO, s) | ||||
| } | ||||
|  | ||||
| func SysLogf(format string, a ...any) { | ||||
| 	SysLog(fmt.Sprintf(format, a...)) | ||||
| 	logHelper(nil, loggerINFO, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func SysError(s string) { | ||||
| 	t := time.Now() | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||
| 	logHelper(nil, loggerError, s) | ||||
| } | ||||
|  | ||||
| func SysErrorf(format string, a ...any) { | ||||
| 	SysError(fmt.Sprintf(format, a...)) | ||||
| 	logHelper(nil, loggerError, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Debug(ctx context.Context, msg string) { | ||||
| 	if config.DebugEnabled { | ||||
| 		logHelper(ctx, loggerDEBUG, msg) | ||||
| 	if !config.DebugEnabled { | ||||
| 		return | ||||
| 	} | ||||
| 	logHelper(ctx, loggerDEBUG, msg) | ||||
| } | ||||
|  | ||||
| func Info(ctx context.Context, msg string) { | ||||
| @@ -75,37 +85,65 @@ func Error(ctx context.Context, msg string) { | ||||
| } | ||||
|  | ||||
| func Debugf(ctx context.Context, format string, a ...any) { | ||||
| 	Debug(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Infof(ctx context.Context, format string, a ...any) { | ||||
| 	Info(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Warnf(ctx context.Context, format string, a ...any) { | ||||
| 	Warn(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func Errorf(ctx context.Context, format string, a ...any) { | ||||
| 	Error(ctx, fmt.Sprintf(format, a...)) | ||||
| 	logHelper(ctx, loggerError, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func logHelper(ctx context.Context, level string, msg string) { | ||||
| func FatalLog(s string) { | ||||
| 	logHelper(nil, loggerFatal, s) | ||||
| } | ||||
|  | ||||
| func FatalLogf(format string, a ...any) { | ||||
| 	logHelper(nil, loggerFatal, fmt.Sprintf(format, a...)) | ||||
| } | ||||
|  | ||||
| func logHelper(ctx context.Context, level loggerLevel, msg string) { | ||||
| 	writer := gin.DefaultErrorWriter | ||||
| 	if level == loggerINFO { | ||||
| 		writer = gin.DefaultWriter | ||||
| 	} | ||||
| 	id := ctx.Value(helper.RequestIdKey) | ||||
| 	if id == nil { | ||||
| 		id = helper.GenRequestID() | ||||
| 	var requestId string | ||||
| 	if ctx != nil { | ||||
| 		rawRequestId := helper.GetRequestID(ctx) | ||||
| 		if rawRequestId != "" { | ||||
| 			requestId = fmt.Sprintf(" | %s", rawRequestId) | ||||
| 		} | ||||
| 	} | ||||
| 	lineInfo, funcName := getLineInfo() | ||||
| 	now := time.Now() | ||||
| 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||
| 	_, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg) | ||||
| 	SetupLogger() | ||||
| 	if level == loggerFatal { | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func FatalLog(v ...any) { | ||||
| 	t := time.Now() | ||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | ||||
| 	os.Exit(1) | ||||
| func getLineInfo() (string, string) { | ||||
| 	funcName := "[unknown] " | ||||
| 	pc, file, line, ok := runtime.Caller(3) | ||||
| 	if ok { | ||||
| 		if fn := runtime.FuncForPC(pc); fn != nil { | ||||
| 			parts := strings.Split(fn.Name(), ".") | ||||
| 			funcName = "[" + parts[len(parts)-1] + "] " | ||||
| 		} | ||||
| 	} else { | ||||
| 		file = "unknown" | ||||
| 		line = 0 | ||||
| 	} | ||||
| 	parts := strings.Split(file, "one-api/") | ||||
| 	if len(parts) > 1 { | ||||
| 		file = parts[1] | ||||
| 	} | ||||
| 	return fmt.Sprintf(" | %s:%d", file, line), funcName | ||||
| } | ||||
|   | ||||
| @@ -6,11 +6,16 @@ import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"net" | ||||
| 	"net/smtp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func shouldAuth() bool { | ||||
| 	return config.SMTPAccount != "" || config.SMTPToken != "" | ||||
| } | ||||
|  | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| 	if receiver == "" { | ||||
| 		return fmt.Errorf("receiver is empty") | ||||
| @@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 		"Date: %s\r\n"+ | ||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||
| 		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||
|  | ||||
| 	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) | ||||
| 	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) | ||||
| 	to := strings.Split(receiver, ";") | ||||
|  | ||||
| 	if config.SMTPPort == 465 { | ||||
| 		tlsConfig := &tls.Config{ | ||||
| 			InsecureSkipVerify: true, | ||||
| 			ServerName:         config.SMTPServer, | ||||
| 	if config.SMTPPort == 465 || !shouldAuth() { | ||||
| 		// need advanced client | ||||
| 		var conn net.Conn | ||||
| 		var err error | ||||
| 		if config.SMTPPort == 465 { | ||||
| 			tlsConfig := &tls.Config{ | ||||
| 				InsecureSkipVerify: true, | ||||
| 				ServerName:         config.SMTPServer, | ||||
| 			} | ||||
| 			conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) | ||||
| 		} else { | ||||
| 			conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)) | ||||
| 		} | ||||
| 		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 			return err | ||||
| 		} | ||||
| 		defer client.Close() | ||||
| 		if err = client.Auth(auth); err != nil { | ||||
| 			return err | ||||
| 		if shouldAuth() { | ||||
| 			if err = client.Auth(auth); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		if err = client.Mail(config.SMTPFrom); err != nil { | ||||
| 			return err | ||||
|   | ||||
| @@ -2,13 +2,15 @@ package common | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"os" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var RDB *redis.Client | ||||
| var RDB redis.Cmdable | ||||
| var RedisEnabled = true | ||||
|  | ||||
| // InitRedisClient This function is called after init() | ||||
| @@ -23,13 +25,23 @@ func InitRedisClient() (err error) { | ||||
| 		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") | ||||
| 		return nil | ||||
| 	} | ||||
| 	logger.SysLog("Redis is enabled") | ||||
| 	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 	redisConnString := os.Getenv("REDIS_CONN_STRING") | ||||
| 	if os.Getenv("REDIS_MASTER_NAME") == "" { | ||||
| 		logger.SysLog("Redis is enabled") | ||||
| 		opt, err := redis.ParseURL(redisConnString) | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to parse Redis connection string: " + err.Error()) | ||||
| 		} | ||||
| 		RDB = redis.NewClient(opt) | ||||
| 	} else { | ||||
| 		// cluster mode | ||||
| 		logger.SysLog("Redis cluster mode enabled") | ||||
| 		RDB = redis.NewUniversalClient(&redis.UniversalOptions{ | ||||
| 			Addrs:      strings.Split(redisConnString, ","), | ||||
| 			Password:   os.Getenv("REDIS_PASSWORD"), | ||||
| 			MasterName: os.Getenv("REDIS_MASTER_NAME"), | ||||
| 		}) | ||||
| 	} | ||||
| 	RDB = redis.NewClient(opt) | ||||
|  | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
|   | ||||
| @@ -3,9 +3,10 @@ package render | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func StringData(c *gin.Context, str string) { | ||||
|   | ||||
| @@ -5,16 +5,18 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type GitHubOAuthResponse struct { | ||||
| @@ -81,6 +83,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | ||||
| } | ||||
|  | ||||
| func GitHubOAuth(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| @@ -136,7 +139,7 @@ func GitHubOAuth(c *gin.Context) { | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 			if err := user.Insert(ctx, 0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
|   | ||||
| @@ -5,15 +5,17 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type LarkOAuthResponse struct { | ||||
| @@ -40,7 +42,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -79,6 +81,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||
| } | ||||
|  | ||||
| func LarkOAuth(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| @@ -125,7 +128,7 @@ func LarkOAuth(c *gin.Context) { | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 			if err := user.Insert(ctx, 0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
|   | ||||
							
								
								
									
										228
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,228 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| ) | ||||
|  | ||||
| type OidcResponse struct { | ||||
| 	AccessToken  string `json:"access_token"` | ||||
| 	IDToken      string `json:"id_token"` | ||||
| 	RefreshToken string `json:"refresh_token"` | ||||
| 	TokenType    string `json:"token_type"` | ||||
| 	ExpiresIn    int    `json:"expires_in"` | ||||
| 	Scope        string `json:"scope"` | ||||
| } | ||||
|  | ||||
| type OidcUser struct { | ||||
| 	OpenID            string `json:"sub"` | ||||
| 	Email             string `json:"email"` | ||||
| 	Name              string `json:"name"` | ||||
| 	PreferredUsername string `json:"preferred_username"` | ||||
| 	Picture           string `json:"picture"` | ||||
| } | ||||
|  | ||||
| func getOidcUserInfoByCode(code string) (*OidcUser, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("无效的参数") | ||||
| 	} | ||||
| 	values := map[string]string{ | ||||
| 		"client_id":     config.OidcClientId, | ||||
| 		"client_secret": config.OidcClientSecret, | ||||
| 		"code":          code, | ||||
| 		"grant_type":    "authorization_code", | ||||
| 		"redirect_uri":  fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(values) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	req.Header.Set("Accept", "application/json") | ||||
| 	client := http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	var oidcResponse OidcResponse | ||||
| 	err = json.NewDecoder(res.Body).Decode(&oidcResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) | ||||
| 	res2, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") | ||||
| 	} | ||||
| 	var oidcUser OidcUser | ||||
| 	err = json.NewDecoder(res2.Body).Decode(&oidcUser) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &oidcUser, nil | ||||
| } | ||||
|  | ||||
| func OidcAuth(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| 		c.JSON(http.StatusForbidden, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "state is empty or not same", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	username := session.Get("username") | ||||
| 	if username != nil { | ||||
| 		OidcBind(c) | ||||
| 		return | ||||
| 	} | ||||
| 	if !config.OidcEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 OIDC 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
| 	oidcUser, err := getOidcUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		OidcId: oidcUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsOidcIdAlreadyTaken(user.OidcId) { | ||||
| 		err := user.FillUserByOidcId() | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": err.Error(), | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if config.RegisterEnabled { | ||||
| 			user.Email = oidcUser.Email | ||||
| 			if oidcUser.PreferredUsername != "" { | ||||
| 				user.Username = oidcUser.PreferredUsername | ||||
| 			} else { | ||||
| 				user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			} | ||||
| 			if oidcUser.Name != "" { | ||||
| 				user.DisplayName = oidcUser.Name | ||||
| 			} else { | ||||
| 				user.DisplayName = "OIDC User" | ||||
| 			} | ||||
| 			err := user.Insert(ctx, 0) | ||||
| 			if err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员关闭了新用户注册", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if user.Status != model.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| func OidcBind(c *gin.Context) { | ||||
| 	if !config.OidcEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 OIDC 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
| 	oidcUser, err := getOidcUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		OidcId: oidcUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsOidcIdAlreadyTaken(user.OidcId) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "该 OIDC 账户已被绑定", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	session := sessions.Default(c) | ||||
| 	id := session.Get("id") | ||||
| 	// id := c.GetInt("id")  // critical bug! | ||||
| 	user.Id = id.(int) | ||||
| 	err = user.FillUserById() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.OidcId = oidcUser.OpenID | ||||
| 	err = user.Update(false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "bind", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -4,14 +4,16 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type wechatLoginResponse struct { | ||||
| @@ -52,6 +54,7 @@ func getWeChatIdByCode(code string) (string, error) { | ||||
| } | ||||
|  | ||||
| func WeChatAuth(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	if !config.WeChatAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员未开启通过微信登录以及注册", | ||||
| @@ -87,7 +90,7 @@ func WeChatAuth(c *gin.Context) { | ||||
| 			user.Role = model.RoleCommonUser | ||||
| 			user.Status = model.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 			if err := user.Insert(ctx, 0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
|   | ||||
| @@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) { | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt(ctxkey.TokenId) | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		expiredTime = token.ExpiredTime | ||||
| 		remainQuota = token.RemainQuota | ||||
| 		usedQuota = token.UsedQuota | ||||
| 		if err == nil { | ||||
| 			expiredTime = token.ExpiredTime | ||||
| 			remainQuota = token.RemainQuota | ||||
| 			usedQuota = token.UsedQuota | ||||
| 		} | ||||
| 	} else { | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		remainQuota, err = model.GetUserQuota(userId) | ||||
|   | ||||
| @@ -4,16 +4,17 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/client" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| @@ -81,6 +82,36 @@ type APGC2DGPTUsageResponse struct { | ||||
| 	TotalUsed      float64 `json:"total_used"` | ||||
| } | ||||
|  | ||||
| type SiliconFlowUsageResponse struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  bool   `json:"status"` | ||||
| 	Data    struct { | ||||
| 		ID            string `json:"id"` | ||||
| 		Name          string `json:"name"` | ||||
| 		Image         string `json:"image"` | ||||
| 		Email         string `json:"email"` | ||||
| 		IsAdmin       bool   `json:"isAdmin"` | ||||
| 		Balance       string `json:"balance"` | ||||
| 		Status        string `json:"status"` | ||||
| 		Introduction  string `json:"introduction"` | ||||
| 		Role          string `json:"role"` | ||||
| 		ChargeBalance string `json:"chargeBalance"` | ||||
| 		TotalBalance  string `json:"totalBalance"` | ||||
| 		Category      string `json:"category"` | ||||
| 	} `json:"data"` | ||||
| } | ||||
|  | ||||
| type DeepSeekUsageResponse struct { | ||||
| 	IsAvailable  bool `json:"is_available"` | ||||
| 	BalanceInfos []struct { | ||||
| 		Currency        string `json:"currency"` | ||||
| 		TotalBalance    string `json:"total_balance"` | ||||
| 		GrantedBalance  string `json:"granted_balance"` | ||||
| 		ToppedUpBalance string `json:"topped_up_balance"` | ||||
| 	} `json:"balance_infos"` | ||||
| } | ||||
|  | ||||
| // GetAuthHeader get auth header | ||||
| func GetAuthHeader(token string) http.Header { | ||||
| 	h := http.Header{} | ||||
| @@ -203,6 +234,57 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
| 	return response.TotalAvailable, nil | ||||
| } | ||||
|  | ||||
| func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.siliconflow.cn/v1/user/info" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := SiliconFlowUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if response.Code != 20000 { | ||||
| 		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) | ||||
| 	} | ||||
| 	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
| } | ||||
|  | ||||
| func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := "https://api.deepseek.com/user/balance" | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	response := DeepSeekUsageResponse{} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	index := -1 | ||||
| 	for i, balanceInfo := range response.BalanceInfos { | ||||
| 		if balanceInfo.Currency == "CNY" { | ||||
| 			index = i | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if index == -1 { | ||||
| 		return 0, errors.New("currency CNY not found") | ||||
| 	} | ||||
| 	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	channel.UpdateBalance(balance) | ||||
| 	return balance, nil | ||||
| } | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := channeltype.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.GetBaseURL() == "" { | ||||
| @@ -227,6 +309,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 		return updateChannelAPI2GPTBalance(channel) | ||||
| 	case channeltype.AIGC2D: | ||||
| 		return updateChannelAIGC2DBalance(channel) | ||||
| 	case channeltype.SiliconFlow: | ||||
| 		return updateChannelSiliconFlowBalance(channel) | ||||
| 	case channeltype.DeepSeek: | ||||
| 		return updateChannelDeepSeekBalance(channel) | ||||
| 	default: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	} | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| @@ -14,38 +15,58 @@ import ( | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	relay "github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/controller" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | ||||
| func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { | ||||
| 	if model == "" { | ||||
| 		model = "gpt-3.5-turbo" | ||||
| 	} | ||||
| 	testRequest := &relaymodel.GeneralOpenAIRequest{ | ||||
| 		MaxTokens: 2, | ||||
| 		Stream:    false, | ||||
| 		Model:     "gpt-3.5-turbo", | ||||
| 		Model: model, | ||||
| 	} | ||||
| 	testMessage := relaymodel.Message{ | ||||
| 		Role:    "user", | ||||
| 		Content: "hi", | ||||
| 		Content: config.TestPrompt, | ||||
| 	} | ||||
| 	testRequest.Messages = append(testRequest.Messages, testMessage) | ||||
| 	return testRequest | ||||
| } | ||||
|  | ||||
| func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { | ||||
| func parseTestResponse(resp string) (*openai.TextResponse, string, error) { | ||||
| 	var response openai.TextResponse | ||||
| 	err := json.Unmarshal([]byte(resp), &response) | ||||
| 	if err != nil { | ||||
| 		return nil, "", err | ||||
| 	} | ||||
| 	if len(response.Choices) == 0 { | ||||
| 		return nil, "", errors.New("response has no choices") | ||||
| 	} | ||||
| 	stringContent, ok := response.Choices[0].Content.(string) | ||||
| 	if !ok { | ||||
| 		return nil, "", errors.New("response content is not string") | ||||
| 	} | ||||
| 	return &response, stringContent, nil | ||||
| } | ||||
|  | ||||
| func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) { | ||||
| 	startTime := time.Now() | ||||
| 	w := httptest.NewRecorder() | ||||
| 	c, _ := gin.CreateTestContext(w) | ||||
| 	c.Request = &http.Request{ | ||||
| @@ -65,64 +86,87 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 	apiType := channeltype.ToAPIType(channel.Type) | ||||
| 	adaptor := relay.GetAdaptor(apiType) | ||||
| 	if adaptor == nil { | ||||
| 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||
| 		return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
| 	var modelName string | ||||
| 	modelList := adaptor.GetModelList() | ||||
| 	modelName := request.Model | ||||
| 	modelMap := channel.GetModelMapping() | ||||
| 	if len(modelList) != 0 { | ||||
| 		modelName = modelList[0] | ||||
| 	} | ||||
| 	if modelName == "" || !strings.Contains(channel.Models, modelName) { | ||||
| 		modelNames := strings.Split(channel.Models, ",") | ||||
| 		if len(modelNames) > 0 { | ||||
| 			modelName = modelNames[0] | ||||
| 		} | ||||
| 		if modelMap != nil && modelMap[modelName] != "" { | ||||
| 			modelName = modelMap[modelName] | ||||
| 		} | ||||
| 	} | ||||
| 	request := buildTestRequest() | ||||
| 	if modelMap != nil && modelMap[modelName] != "" { | ||||
| 		modelName = modelMap[modelName] | ||||
| 	} | ||||
| 	meta.OriginModelName, meta.ActualModelName = request.Model, modelName | ||||
| 	request.Model = modelName | ||||
| 	meta.OriginModelName, meta.ActualModelName = modelName, modelName | ||||
| 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(convertedRequest) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage) | ||||
| 		if err != nil || openaiErr != nil { | ||||
| 			errorMessage := "" | ||||
| 			if err != nil { | ||||
| 				errorMessage = err.Error() | ||||
| 			} else { | ||||
| 				errorMessage = openaiErr.Message | ||||
| 			} | ||||
| 			logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage) | ||||
| 		} | ||||
| 		go model.RecordTestLog(ctx, &model.Log{ | ||||
| 			ChannelId:   channel.Id, | ||||
| 			ModelName:   modelName, | ||||
| 			Content:     logContent, | ||||
| 			ElapsedTime: helper.CalcElapsedTime(startTime), | ||||
| 		}) | ||||
| 	}() | ||||
| 	logger.SysLog(string(jsonData)) | ||||
| 	requestBody := bytes.NewBuffer(jsonData) | ||||
| 	c.Request.Body = io.NopCloser(requestBody) | ||||
| 	resp, err := adaptor.DoRequest(c, meta, requestBody) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	if resp != nil && resp.StatusCode != http.StatusOK { | ||||
| 		err := controller.RelayErrorHandler(resp) | ||||
| 		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error | ||||
| 		errorMessage := err.Error.Message | ||||
| 		if errorMessage != "" { | ||||
| 			errorMessage = ", error message: " + errorMessage | ||||
| 		} | ||||
| 		return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error | ||||
| 	} | ||||
| 	usage, respErr := adaptor.DoResponse(c, resp, meta) | ||||
| 	if respErr != nil { | ||||
| 		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error | ||||
| 		return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error | ||||
| 	} | ||||
| 	if usage == nil { | ||||
| 		return errors.New("usage is nil"), nil | ||||
| 		return "", errors.New("usage is nil"), nil | ||||
| 	} | ||||
| 	rawResponse := w.Body.String() | ||||
| 	_, responseMessage, err = parseTestResponse(rawResponse) | ||||
| 	if err != nil { | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	result := w.Result() | ||||
| 	// print result.Body | ||||
| 	respBody, err := io.ReadAll(result.Body) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return "", err, nil | ||||
| 	} | ||||
| 	logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) | ||||
| 	return nil, nil | ||||
| 	return responseMessage, nil, nil | ||||
| } | ||||
|  | ||||
| func TestChannel(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	id, err := strconv.Atoi(c.Param("id")) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -139,24 +183,31 @@ func TestChannel(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	modelName := c.Query("model") | ||||
| 	testRequest := buildTestRequest(modelName) | ||||
| 	tik := time.Now() | ||||
| 	err, _ = testChannel(channel) | ||||
| 	responseMessage, err, _ := testChannel(ctx, channel, testRequest) | ||||
| 	tok := time.Now() | ||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 	if err != nil { | ||||
| 		milliseconds = 0 | ||||
| 	} | ||||
| 	go channel.UpdateResponseTime(milliseconds) | ||||
| 	consumedTime := float64(milliseconds) / 1000.0 | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 			"time":    consumedTime, | ||||
| 			"success":   false, | ||||
| 			"message":   err.Error(), | ||||
| 			"time":      consumedTime, | ||||
| 			"modelName": modelName, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"time":    consumedTime, | ||||
| 		"success":   true, | ||||
| 		"message":   responseMessage, | ||||
| 		"time":      consumedTime, | ||||
| 		"modelName": modelName, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -164,7 +215,7 @@ func TestChannel(c *gin.Context) { | ||||
| var testAllChannelsLock sync.Mutex | ||||
| var testAllChannelsRunning bool = false | ||||
|  | ||||
| func testChannels(notify bool, scope string) error { | ||||
| func testChannels(ctx context.Context, notify bool, scope string) error { | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| @@ -187,11 +238,12 @@ func testChannels(notify bool, scope string) error { | ||||
| 		for _, channel := range channels { | ||||
| 			isChannelEnabled := channel.Status == model.ChannelStatusEnabled | ||||
| 			tik := time.Now() | ||||
| 			err, openaiErr := testChannel(channel) | ||||
| 			testRequest := buildTestRequest("") | ||||
| 			_, err, openaiErr := testChannel(ctx, channel, testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) | ||||
| 				if config.AutomaticDisableChannelEnabled { | ||||
| 					monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 				} else { | ||||
| @@ -221,11 +273,12 @@ func testChannels(notify bool, scope string) error { | ||||
| } | ||||
|  | ||||
| func TestChannels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	scope := c.Query("scope") | ||||
| 	if scope == "" { | ||||
| 		scope = "all" | ||||
| 	} | ||||
| 	err := testChannels(true, scope) | ||||
| 	err := testChannels(ctx, true, scope) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -241,10 +294,11 @@ func TestChannels(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func AutomaticallyTestChannels(frequency int) { | ||||
| 	ctx := context.Background() | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||
| 		logger.SysLog("testing all channels") | ||||
| 		_ = testChannels(false, "all") | ||||
| 		_ = testChannels(ctx, false, "all") | ||||
| 		logger.SysLog("channel test finished") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) { | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data": gin.H{ | ||||
| 			"version":             common.Version, | ||||
| 			"start_time":          common.StartTime, | ||||
| 			"email_verification":  config.EmailVerificationEnabled, | ||||
| 			"github_oauth":        config.GitHubOAuthEnabled, | ||||
| 			"github_client_id":    config.GitHubClientId, | ||||
| 			"lark_client_id":      config.LarkClientId, | ||||
| 			"system_name":         config.SystemName, | ||||
| 			"logo":                config.Logo, | ||||
| 			"footer_html":         config.Footer, | ||||
| 			"wechat_qrcode":       config.WeChatAccountQRCodeImageURL, | ||||
| 			"wechat_login":        config.WeChatAuthEnabled, | ||||
| 			"server_address":      config.ServerAddress, | ||||
| 			"turnstile_check":     config.TurnstileCheckEnabled, | ||||
| 			"turnstile_site_key":  config.TurnstileSiteKey, | ||||
| 			"top_up_link":         config.TopUpLink, | ||||
| 			"chat_link":           config.ChatLink, | ||||
| 			"quota_per_unit":      config.QuotaPerUnit, | ||||
| 			"display_in_currency": config.DisplayInCurrencyEnabled, | ||||
| 			"version":                     common.Version, | ||||
| 			"start_time":                  common.StartTime, | ||||
| 			"email_verification":          config.EmailVerificationEnabled, | ||||
| 			"github_oauth":                config.GitHubOAuthEnabled, | ||||
| 			"github_client_id":            config.GitHubClientId, | ||||
| 			"lark_client_id":              config.LarkClientId, | ||||
| 			"system_name":                 config.SystemName, | ||||
| 			"logo":                        config.Logo, | ||||
| 			"footer_html":                 config.Footer, | ||||
| 			"wechat_qrcode":               config.WeChatAccountQRCodeImageURL, | ||||
| 			"wechat_login":                config.WeChatAuthEnabled, | ||||
| 			"server_address":              config.ServerAddress, | ||||
| 			"turnstile_check":             config.TurnstileCheckEnabled, | ||||
| 			"turnstile_site_key":          config.TurnstileSiteKey, | ||||
| 			"top_up_link":                 config.TopUpLink, | ||||
| 			"chat_link":                   config.ChatLink, | ||||
| 			"quota_per_unit":              config.QuotaPerUnit, | ||||
| 			"display_in_currency":         config.DisplayInCurrencyEnabled, | ||||
| 			"oidc":                        config.OidcEnabled, | ||||
| 			"oidc_client_id":              config.OidcClientId, | ||||
| 			"oidc_well_known":             config.OidcWellKnown, | ||||
| 			"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, | ||||
| 			"oidc_token_endpoint":         config.OidcTokenEndpoint, | ||||
| 			"oidc_userinfo_endpoint":      config.OidcUserinfoEndpoint, | ||||
| 		}, | ||||
| 	}) | ||||
| 	return | ||||
|   | ||||
| @@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { | ||||
| 		fallthrough | ||||
| 	case relaymode.AudioTranscription: | ||||
| 		err = controller.RelayAudioHelper(c, relayMode) | ||||
| 	case relaymode.Proxy: | ||||
| 		err = controller.RelayProxyHelper(c, relayMode) | ||||
| 	default: | ||||
| 		err = controller.RelayTextHelper(c) | ||||
| 	} | ||||
| @@ -58,7 +60,7 @@ func Relay(c *gin.Context) { | ||||
| 	channelName := c.GetString(ctxkey.ChannelName) | ||||
| 	group := c.GetString(ctxkey.Group) | ||||
| 	originalModel := c.GetString(ctxkey.OriginalModel) | ||||
| 	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||
| 	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||
| 	requestId := c.GetString(helper.RequestIdKey) | ||||
| 	retryTimes := config.RetryTimes | ||||
| 	if !shouldRetry(c, bizErr.StatusCode) { | ||||
| @@ -85,12 +87,14 @@ func Relay(c *gin.Context) { | ||||
| 		channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 		lastFailedChannelId = channelId | ||||
| 		channelName := c.GetString(ctxkey.ChannelName) | ||||
| 		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) | ||||
| 		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) | ||||
| 	} | ||||
| 	if bizErr != nil { | ||||
| 		if bizErr.StatusCode == http.StatusTooManyRequests { | ||||
| 			bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" | ||||
| 		} | ||||
|  | ||||
| 		// BUG: bizErr is in race condition | ||||
| 		bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) | ||||
| 		c.JSON(bizErr.StatusCode, gin.H{ | ||||
| 			"error": bizErr.Error, | ||||
| @@ -117,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { | ||||
| func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { | ||||
| 	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) | ||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||
|   | ||||
| @@ -109,6 +109,7 @@ func Logout(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func Register(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	if !config.RegisterEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "管理员关闭了新用户注册", | ||||
| @@ -166,7 +167,7 @@ func Register(c *gin.Context) { | ||||
| 	if config.EmailVerificationEnabled { | ||||
| 		cleanUser.Email = user.Email | ||||
| 	} | ||||
| 	if err := cleanUser.Insert(inviterId); err != nil { | ||||
| 	if err := cleanUser.Insert(ctx, inviterId); err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| @@ -362,6 +363,7 @@ func GetSelf(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func UpdateUser(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var updatedUser model.User | ||||
| 	err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) | ||||
| 	if err != nil || updatedUser.Id == 0 { | ||||
| @@ -416,7 +418,7 @@ func UpdateUser(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	if originUser.Quota != updatedUser.Quota { | ||||
| 		model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) | ||||
| 		model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| @@ -535,6 +537,7 @@ func DeleteSelf(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func CreateUser(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var user model.User | ||||
| 	err := json.NewDecoder(c.Request.Body).Decode(&user) | ||||
| 	if err != nil || user.Username == "" || user.Password == "" { | ||||
| @@ -568,7 +571,7 @@ func CreateUser(c *gin.Context) { | ||||
| 		Password:    user.Password, | ||||
| 		DisplayName: user.DisplayName, | ||||
| 	} | ||||
| 	if err := cleanUser.Insert(0); err != nil { | ||||
| 	if err := cleanUser.Insert(ctx, 0); err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| @@ -747,6 +750,7 @@ type topUpRequest struct { | ||||
| } | ||||
|  | ||||
| func TopUp(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	req := topUpRequest{} | ||||
| 	err := c.ShouldBindJSON(&req) | ||||
| 	if err != nil { | ||||
| @@ -757,7 +761,7 @@ func TopUp(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 	id := c.GetInt("id") | ||||
| 	quota, err := model.Redeem(req.Key, id) | ||||
| 	quota, err := model.Redeem(ctx, req.Key, id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -780,6 +784,7 @@ type adminTopUpRequest struct { | ||||
| } | ||||
|  | ||||
| func AdminTopUp(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	req := adminTopUpRequest{} | ||||
| 	err := c.ShouldBindJSON(&req) | ||||
| 	if err != nil { | ||||
| @@ -800,7 +805,7 @@ func AdminTopUp(c *gin.Context) { | ||||
| 	if req.Remark == "" { | ||||
| 		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) | ||||
| 	} | ||||
| 	model.RecordTopupLog(req.UserId, req.Remark, req.Quota) | ||||
| 	model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota) | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
|   | ||||
							
								
								
									
										42
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										42
									
								
								go.mod
									
									
									
									
									
								
							| @@ -1,9 +1,9 @@ | ||||
| module github.com/songquanpeng/one-api | ||||
|  | ||||
| // +heroku goVersion go1.18 | ||||
| go 1.20 | ||||
|  | ||||
| require ( | ||||
| 	cloud.google.com/go/iam v1.1.10 | ||||
| 	github.com/aws/aws-sdk-go-v2 v1.27.0 | ||||
| 	github.com/aws/aws-sdk-go-v2/credentials v1.17.15 | ||||
| 	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 | ||||
| @@ -19,19 +19,25 @@ require ( | ||||
| 	github.com/gorilla/websocket v1.5.1 | ||||
| 	github.com/jinzhu/copier v0.4.0 | ||||
| 	github.com/joho/godotenv v1.5.1 | ||||
| 	github.com/patrickmn/go-cache v2.1.0+incompatible | ||||
| 	github.com/pkg/errors v0.9.1 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.7 | ||||
| 	github.com/smartystreets/goconvey v1.8.1 | ||||
| 	github.com/stretchr/testify v1.9.0 | ||||
| 	golang.org/x/crypto v0.23.0 | ||||
| 	golang.org/x/crypto v0.31.0 | ||||
| 	golang.org/x/image v0.18.0 | ||||
| 	golang.org/x/sync v0.10.0 | ||||
| 	google.golang.org/api v0.187.0 | ||||
| 	gorm.io/driver/mysql v1.5.6 | ||||
| 	gorm.io/driver/postgres v1.5.7 | ||||
| 	gorm.io/driver/sqlite v1.5.5 | ||||
| 	gorm.io/driver/sqlite v1.5.1 | ||||
| 	gorm.io/gorm v1.25.10 | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	cloud.google.com/go/auth v0.6.1 // indirect | ||||
| 	cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect | ||||
| 	cloud.google.com/go/compute/metadata v0.3.0 // indirect | ||||
| 	filippo.io/edwards25519 v1.1.0 // indirect | ||||
| 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect | ||||
| 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect | ||||
| @@ -45,13 +51,21 @@ require ( | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||
| 	github.com/dlclark/regexp2 v1.11.0 // indirect | ||||
| 	github.com/felixge/httpsnoop v1.0.4 // indirect | ||||
| 	github.com/fsnotify/fsnotify v1.7.0 // indirect | ||||
| 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect | ||||
| 	github.com/gin-contrib/sse v0.1.0 // indirect | ||||
| 	github.com/go-logr/logr v1.4.1 // indirect | ||||
| 	github.com/go-logr/stdr v1.2.2 // indirect | ||||
| 	github.com/go-playground/locales v0.14.1 // indirect | ||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.8.1 // indirect | ||||
| 	github.com/goccy/go-json v0.10.3 // indirect | ||||
| 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect | ||||
| 	github.com/golang/protobuf v1.5.4 // indirect | ||||
| 	github.com/google/s2a-go v0.1.7 // indirect | ||||
| 	github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect | ||||
| 	github.com/googleapis/gax-go/v2 v2.12.5 // indirect | ||||
| 	github.com/gopherjs/gopherjs v1.17.2 // indirect | ||||
| 	github.com/gorilla/context v1.1.2 // indirect | ||||
| 	github.com/gorilla/securecookie v1.1.2 // indirect | ||||
| @@ -68,7 +82,7 @@ require ( | ||||
| 	github.com/kr/text v0.2.0 // indirect | ||||
| 	github.com/leodido/go-urn v1.4.0 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.20 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | ||||
| 	github.com/mattn/go-sqlite3 v1.14.24 // indirect | ||||
| 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect | ||||
| 	github.com/modern-go/reflect2 v1.0.2 // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.2.2 // indirect | ||||
| @@ -76,11 +90,21 @@ require ( | ||||
| 	github.com/smarty/assertions v1.15.0 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	github.com/ugorji/go/codec v1.2.12 // indirect | ||||
| 	go.opencensus.io v0.24.0 // indirect | ||||
| 	go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect | ||||
| 	go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect | ||||
| 	go.opentelemetry.io/otel v1.24.0 // indirect | ||||
| 	go.opentelemetry.io/otel/metric v1.24.0 // indirect | ||||
| 	go.opentelemetry.io/otel/trace v1.24.0 // indirect | ||||
| 	golang.org/x/arch v0.8.0 // indirect | ||||
| 	golang.org/x/net v0.25.0 // indirect | ||||
| 	golang.org/x/sync v0.7.0 // indirect | ||||
| 	golang.org/x/sys v0.20.0 // indirect | ||||
| 	golang.org/x/text v0.16.0 // indirect | ||||
| 	google.golang.org/protobuf v1.34.1 // indirect | ||||
| 	golang.org/x/net v0.26.0 // indirect | ||||
| 	golang.org/x/oauth2 v0.21.0 // indirect | ||||
| 	golang.org/x/sys v0.28.0 // indirect | ||||
| 	golang.org/x/text v0.21.0 // indirect | ||||
| 	golang.org/x/time v0.5.0 // indirect | ||||
| 	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect | ||||
| 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect | ||||
| 	google.golang.org/grpc v1.64.1 // indirect | ||||
| 	google.golang.org/protobuf v1.34.2 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										164
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										164
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,5 +1,15 @@ | ||||
| cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= | ||||
| cloud.google.com/go/auth v0.6.1 h1:T0Zw1XM5c1GlpN2HYr2s+m3vr1p2wy+8VN+Z1FKxW38= | ||||
| cloud.google.com/go/auth v0.6.1/go.mod h1:eFHG7zDzbXHKmjJddFG/rBlcGp6t25SwRUiEQSlO4x4= | ||||
| cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= | ||||
| cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= | ||||
| cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= | ||||
| cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= | ||||
| cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= | ||||
| cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= | ||||
| filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= | ||||
| filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= | ||||
| github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= | ||||
| github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= | ||||
| github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= | ||||
| github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= | ||||
| @@ -18,12 +28,15 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc | ||||
| github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= | ||||
| github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= | ||||
| github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= | ||||
| github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= | ||||
| github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= | ||||
| github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= | ||||
| github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= | ||||
| github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= | ||||
| github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= | ||||
| github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= | ||||
| github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= | ||||
| github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= | ||||
| github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||
| github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||||
| @@ -32,6 +45,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r | ||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | ||||
| github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= | ||||
| github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | ||||
| github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= | ||||
| github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= | ||||
| github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= | ||||
| github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= | ||||
| github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= | ||||
| github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= | ||||
| github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= | ||||
| github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= | ||||
| github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= | ||||
| @@ -48,6 +67,11 @@ github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0Nglqm | ||||
| github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw= | ||||
| github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= | ||||
| github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= | ||||
| github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= | ||||
| github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= | ||||
| github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= | ||||
| github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= | ||||
| github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= | ||||
| github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= | ||||
| github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= | ||||
| github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= | ||||
| @@ -64,11 +88,40 @@ github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= | ||||
| github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||
| github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= | ||||
| github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= | ||||
| github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= | ||||
| github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= | ||||
| github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= | ||||
| github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= | ||||
| github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | ||||
| github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | ||||
| github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= | ||||
| github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= | ||||
| github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= | ||||
| github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= | ||||
| github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= | ||||
| github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= | ||||
| github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= | ||||
| github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= | ||||
| github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= | ||||
| github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= | ||||
| github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= | ||||
| github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= | ||||
| github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= | ||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||
| github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= | ||||
| github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= | ||||
| github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= | ||||
| github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||||
| github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= | ||||
| github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||||
| github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= | ||||
| github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= | ||||
| github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= | ||||
| github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= | ||||
| github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= | ||||
| github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= | ||||
| github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= | ||||
| @@ -110,8 +163,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= | ||||
| github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= | ||||
| github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= | ||||
| github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= | ||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= | ||||
| github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= | ||||
| github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= | ||||
| github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= | ||||
| github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | ||||
| github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= | ||||
| github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= | ||||
| @@ -120,6 +173,8 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY | ||||
| github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | ||||
| github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= | ||||
| github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= | ||||
| github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= | ||||
| github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= | ||||
| github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= | ||||
| github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= | ||||
| github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= | ||||
| @@ -128,6 +183,7 @@ github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQ | ||||
| github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | ||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||
| github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= | ||||
| github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= | ||||
| github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= | ||||
| github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= | ||||
| @@ -149,26 +205,96 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS | ||||
| github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | ||||
| github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= | ||||
| github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= | ||||
| go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= | ||||
| go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= | ||||
| go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= | ||||
| go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= | ||||
| go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= | ||||
| go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= | ||||
| go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= | ||||
| go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= | ||||
| go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= | ||||
| go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= | ||||
| go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= | ||||
| go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= | ||||
| golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= | ||||
| golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= | ||||
| golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= | ||||
| golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= | ||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | ||||
| golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= | ||||
| golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= | ||||
| golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||||
| golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= | ||||
| golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= | ||||
| golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= | ||||
| golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= | ||||
| golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= | ||||
| golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= | ||||
| golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= | ||||
| golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= | ||||
| golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= | ||||
| golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= | ||||
| golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= | ||||
| golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||||
| golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||||
| golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= | ||||
| golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= | ||||
| golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= | ||||
| golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= | ||||
| golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= | ||||
| golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= | ||||
| golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= | ||||
| golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= | ||||
| golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= | ||||
| golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||||
| google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= | ||||
| google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= | ||||
| golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= | ||||
| golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= | ||||
| golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= | ||||
| golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= | ||||
| golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= | ||||
| golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= | ||||
| golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo= | ||||
| google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk= | ||||
| google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= | ||||
| google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= | ||||
| google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= | ||||
| google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= | ||||
| google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= | ||||
| google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc= | ||||
| google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c= | ||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk= | ||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= | ||||
| google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= | ||||
| google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= | ||||
| google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= | ||||
| google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= | ||||
| google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= | ||||
| google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= | ||||
| google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= | ||||
| google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= | ||||
| google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= | ||||
| google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= | ||||
| google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= | ||||
| google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= | ||||
| google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= | ||||
| google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= | ||||
| google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= | ||||
| google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= | ||||
| google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= | ||||
| google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||
| gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | ||||
| @@ -180,10 +306,12 @@ gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= | ||||
| gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= | ||||
| gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= | ||||
| gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= | ||||
| gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= | ||||
| gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= | ||||
| gorm.io/driver/sqlite v1.5.1 h1:hYyrLkAWE71bcarJDPdZNTLWtr8XrSjOWyjUYI6xdL4= | ||||
| gorm.io/driver/sqlite v1.5.1/go.mod h1:7MZZ2Z8bqyfSQA1gYEV6MagQWj3cpUkJj9Z+d1HEMEQ= | ||||
| gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= | ||||
| gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= | ||||
| gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= | ||||
| honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= | ||||
| honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= | ||||
| nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= | ||||
| rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= | ||||
|   | ||||
							
								
								
									
										22
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								main.go
									
									
									
									
									
								
							| @@ -27,27 +27,19 @@ func main() { | ||||
| 	common.Init() | ||||
| 	logger.SetupLogger() | ||||
| 	logger.SysLogf("One API %s started", common.Version) | ||||
| 	if os.Getenv("GIN_MODE") != "debug" { | ||||
|  | ||||
| 	if os.Getenv("GIN_MODE") != gin.DebugMode { | ||||
| 		gin.SetMode(gin.ReleaseMode) | ||||
| 	} | ||||
| 	if config.DebugEnabled { | ||||
| 		logger.SysLog("running in debug mode") | ||||
| 	} | ||||
| 	var err error | ||||
|  | ||||
| 	// Initialize SQL Database | ||||
| 	model.DB, err = model.InitDB("SQL_DSN") | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||
| 	} | ||||
| 	if os.Getenv("LOG_SQL_DSN") != "" { | ||||
| 		logger.SysLog("using secondary database for table logs") | ||||
| 		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to initialize secondary database: " + err.Error()) | ||||
| 		} | ||||
| 	} else { | ||||
| 		model.LOG_DB = model.DB | ||||
| 	} | ||||
| 	model.InitDB() | ||||
| 	model.InitLogDB() | ||||
|  | ||||
| 	var err error | ||||
| 	err = model.CreateRootAccountIfNeed() | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("database init error: " + err.Error()) | ||||
|   | ||||
| @@ -140,6 +140,12 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// set channel id for proxy relay | ||||
| 		if channelId := c.Param("channelid"); channelId != "" { | ||||
| 			c.Set(ctxkey.SpecificChannelId, channelId) | ||||
| 		} | ||||
|  | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -2,21 +2,24 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Model string `json:"model" form:"model"` | ||||
| } | ||||
|  | ||||
| func Distribute() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		ctx := c.Request.Context() | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		c.Set(ctxkey.Group, userGroup) | ||||
| @@ -52,6 +55,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id) | ||||
| 		SetupContextForSelectedChannel(c, channel, requestModel) | ||||
| 		c.Next() | ||||
| 	} | ||||
| @@ -61,6 +65,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode | ||||
| 	c.Set(ctxkey.Channel, channel.Type) | ||||
| 	c.Set(ctxkey.ChannelId, channel.Id) | ||||
| 	c.Set(ctxkey.ChannelName, channel.Name) | ||||
| 	if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { | ||||
| 		c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) | ||||
| 	} | ||||
| 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) | ||||
| 	c.Set(ctxkey.OriginalModel, modelName) // for retry | ||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
|   | ||||
							
								
								
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"compress/gzip" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| func GzipDecodeMiddleware() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		if c.GetHeader("Content-Encoding") == "gzip" { | ||||
| 			gzipReader, err := gzip.NewReader(c.Request.Body) | ||||
| 			if err != nil { | ||||
| 				c.AbortWithStatus(http.StatusBadRequest) | ||||
| 				return | ||||
| 			} | ||||
| 			defer gzipReader.Close() | ||||
|  | ||||
| 			// Replace the request body with the decompressed data | ||||
| 			c.Request.Body = io.NopCloser(gzipReader) | ||||
| 		} | ||||
|  | ||||
| 		// Continue processing the request | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
| @@ -3,11 +3,12 @@ package middleware | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var timeFormat = "2006-01-02T15:04:05.000Z" | ||||
| @@ -70,6 +71,11 @@ func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark s | ||||
| } | ||||
|  | ||||
| func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { | ||||
| 	if maxRequestNum == 0 { | ||||
| 		return func(c *gin.Context) { | ||||
| 			c.Next() | ||||
| 		} | ||||
| 	} | ||||
| 	if common.RedisEnabled { | ||||
| 		return func(c *gin.Context) { | ||||
| 			redisRateLimiter(c, maxRequestNum, duration, mark) | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| ) | ||||
|  | ||||
| @@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		id := helper.GenRequestID() | ||||
| 		c.Set(helper.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) | ||||
| 		ctx := helper.SetRequestID(c.Request.Context(), id) | ||||
| 		c.Request = c.Request.WithContext(ctx) | ||||
| 		c.Header(helper.RequestIdKey, id) | ||||
| 		c.Next() | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package model | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| @@ -36,16 +37,19 @@ type Channel struct { | ||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||
| 	Config             string  `json:"config"` | ||||
| 	SystemPrompt       *string `json:"system_prompt" gorm:"type:text"` | ||||
| } | ||||
|  | ||||
| type ChannelConfig struct { | ||||
| 	Region     string `json:"region,omitempty"` | ||||
| 	SK         string `json:"sk,omitempty"` | ||||
| 	AK         string `json:"ak,omitempty"` | ||||
| 	UserID     string `json:"user_id,omitempty"` | ||||
| 	APIVersion string `json:"api_version,omitempty"` | ||||
| 	LibraryID  string `json:"library_id,omitempty"` | ||||
| 	Plugin     string `json:"plugin,omitempty"` | ||||
| 	Region            string `json:"region,omitempty"` | ||||
| 	SK                string `json:"sk,omitempty"` | ||||
| 	AK                string `json:"ak,omitempty"` | ||||
| 	UserID            string `json:"user_id,omitempty"` | ||||
| 	APIVersion        string `json:"api_version,omitempty"` | ||||
| 	LibraryID         string `json:"library_id,omitempty"` | ||||
| 	Plugin            string `json:"plugin,omitempty"` | ||||
| 	VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"` | ||||
| 	VertexAIADC       string `json:"vertex_ai_adc,omitempty"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { | ||||
|   | ||||
							
								
								
									
										100
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -3,26 +3,32 @@ package model | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type Log struct { | ||||
| 	Id               int    `json:"id"` | ||||
| 	UserId           int    `json:"user_id" gorm:"index"` | ||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"` | ||||
| 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||
| 	Content          string `json:"content"` | ||||
| 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||
| 	TokenName        string `json:"token_name" gorm:"index;default:''"` | ||||
| 	ModelName        string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | ||||
| 	Quota            int    `json:"quota" gorm:"default:0"` | ||||
| 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | ||||
| 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | ||||
| 	ChannelId        int    `json:"channel" gorm:"index"` | ||||
| 	Id                int    `json:"id"` | ||||
| 	UserId            int    `json:"user_id" gorm:"index"` | ||||
| 	CreatedAt         int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"` | ||||
| 	Type              int    `json:"type" gorm:"index:idx_created_at_type"` | ||||
| 	Content           string `json:"content"` | ||||
| 	Username          string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||
| 	TokenName         string `json:"token_name" gorm:"index;default:''"` | ||||
| 	ModelName         string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | ||||
| 	Quota             int    `json:"quota" gorm:"default:0"` | ||||
| 	PromptTokens      int    `json:"prompt_tokens" gorm:"default:0"` | ||||
| 	CompletionTokens  int    `json:"completion_tokens" gorm:"default:0"` | ||||
| 	ChannelId         int    `json:"channel" gorm:"index"` | ||||
| 	RequestId         string `json:"request_id" gorm:"default:''"` | ||||
| 	ElapsedTime       int64  `json:"elapsed_time" gorm:"default:0"` // unit is ms | ||||
| 	IsStream          bool   `json:"is_stream" gorm:"default:false"` | ||||
| 	SystemPromptReset bool   `json:"system_prompt_reset" gorm:"default:false"` | ||||
| } | ||||
|  | ||||
| const ( | ||||
| @@ -31,9 +37,21 @@ const ( | ||||
| 	LogTypeConsume | ||||
| 	LogTypeManage | ||||
| 	LogTypeSystem | ||||
| 	LogTypeTest | ||||
| ) | ||||
|  | ||||
| func RecordLog(userId int, logType int, content string) { | ||||
| func recordLogHelper(ctx context.Context, log *Log) { | ||||
| 	requestId := helper.GetRequestID(ctx) | ||||
| 	log.RequestId = requestId | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	logger.Infof(ctx, "record log: %+v", log) | ||||
| } | ||||
|  | ||||
| func RecordLog(ctx context.Context, userId int, logType int, content string) { | ||||
| 	if logType == LogTypeConsume && !config.LogConsumeEnabled { | ||||
| 		return | ||||
| 	} | ||||
| @@ -44,13 +62,10 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 		Type:      logType, | ||||
| 		Content:   content, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func RecordTopupLog(userId int, content string, quota int) { | ||||
| func RecordTopupLog(ctx context.Context, userId int, content string, quota int) { | ||||
| 	log := &Log{ | ||||
| 		UserId:    userId, | ||||
| 		Username:  GetUsernameById(userId), | ||||
| @@ -59,34 +74,23 @@ func RecordTopupLog(userId int, content string, quota int) { | ||||
| 		Content:   content, | ||||
| 		Quota:     quota, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { | ||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| func RecordConsumeLog(ctx context.Context, log *Log) { | ||||
| 	if !config.LogConsumeEnabled { | ||||
| 		return | ||||
| 	} | ||||
| 	log := &Log{ | ||||
| 		UserId:           userId, | ||||
| 		Username:         GetUsernameById(userId), | ||||
| 		CreatedAt:        helper.GetTimestamp(), | ||||
| 		Type:             LogTypeConsume, | ||||
| 		Content:          content, | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TokenName:        tokenName, | ||||
| 		ModelName:        modelName, | ||||
| 		Quota:            int(quota), | ||||
| 		ChannelId:        channelId, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "failed to record log: "+err.Error()) | ||||
| 	} | ||||
| 	log.Username = GetUsernameById(log.UserId) | ||||
| 	log.CreatedAt = helper.GetTimestamp() | ||||
| 	log.Type = LogTypeConsume | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func RecordTestLog(ctx context.Context, log *Log) { | ||||
| 	log.CreatedAt = helper.GetTimestamp() | ||||
| 	log.Type = LogTypeTest | ||||
| 	recordLogHelper(ctx, log) | ||||
| } | ||||
|  | ||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | ||||
| @@ -152,7 +156,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| 	ifnull := "ifnull" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		ifnull = "COALESCE" | ||||
| 	} | ||||
| 	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull)) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -176,7 +184,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| } | ||||
|  | ||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||
| 	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||
| 	ifnull := "ifnull" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		ifnull = "COALESCE" | ||||
| 	} | ||||
| 	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull)) | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										225
									
								
								model/main.go
									
									
									
									
									
								
							
							
						
						
									
										225
									
								
								model/main.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| @@ -29,13 +30,17 @@ func CreateRootAccountIfNeed() error { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		accessToken := random.GetUUID() | ||||
| 		if config.InitialRootAccessToken != "" { | ||||
| 			accessToken = config.InitialRootAccessToken | ||||
| 		} | ||||
| 		rootUser := User{ | ||||
| 			Username:    "root", | ||||
| 			Password:    hashedPassword, | ||||
| 			Role:        RoleRootUser, | ||||
| 			Status:      UserStatusEnabled, | ||||
| 			DisplayName: "Root User", | ||||
| 			AccessToken: random.GetUUID(), | ||||
| 			AccessToken: accessToken, | ||||
| 			Quota:       500000000000000, | ||||
| 		} | ||||
| 		DB.Create(&rootUser) | ||||
| @@ -60,90 +65,156 @@ func CreateRootAccountIfNeed() error { | ||||
| } | ||||
|  | ||||
| func chooseDB(envName string) (*gorm.DB, error) { | ||||
| 	if os.Getenv(envName) != "" { | ||||
| 		dsn := os.Getenv(envName) | ||||
| 		if strings.HasPrefix(dsn, "postgres://") { | ||||
| 			// Use PostgreSQL | ||||
| 			logger.SysLog("using PostgreSQL as database") | ||||
| 			common.UsingPostgreSQL = true | ||||
| 			return gorm.Open(postgres.New(postgres.Config{ | ||||
| 				DSN:                  dsn, | ||||
| 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||
| 			}), &gorm.Config{ | ||||
| 				PrepareStmt: true, // precompile SQL | ||||
| 			}) | ||||
| 		} | ||||
| 	dsn := os.Getenv(envName) | ||||
|  | ||||
| 	switch { | ||||
| 	case strings.HasPrefix(dsn, "postgres://"): | ||||
| 		// Use PostgreSQL | ||||
| 		return openPostgreSQL(dsn) | ||||
| 	case dsn != "": | ||||
| 		// Use MySQL | ||||
| 		logger.SysLog("using MySQL as database") | ||||
| 		common.UsingMySQL = true | ||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||
| 			PrepareStmt: true, // precompile SQL | ||||
| 		}) | ||||
| 		return openMySQL(dsn) | ||||
| 	default: | ||||
| 		// Use SQLite | ||||
| 		return openSQLite() | ||||
| 	} | ||||
| 	// Use SQLite | ||||
| 	logger.SysLog("SQL_DSN not set, using SQLite as database") | ||||
| 	common.UsingSQLite = true | ||||
| 	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) | ||||
| 	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ | ||||
| } | ||||
|  | ||||
| func openPostgreSQL(dsn string) (*gorm.DB, error) { | ||||
| 	logger.SysLog("using PostgreSQL as database") | ||||
| 	common.UsingPostgreSQL = true | ||||
| 	return gorm.Open(postgres.New(postgres.Config{ | ||||
| 		DSN:                  dsn, | ||||
| 		PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||
| 	}), &gorm.Config{ | ||||
| 		PrepareStmt: true, // precompile SQL | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func InitDB(envName string) (db *gorm.DB, err error) { | ||||
| 	db, err = chooseDB(envName) | ||||
| 	if err == nil { | ||||
| 		if config.DebugSQLEnabled { | ||||
| 			db = db.Debug() | ||||
| 		} | ||||
| 		sqlDB, err := db.DB() | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||
| func openMySQL(dsn string) (*gorm.DB, error) { | ||||
| 	logger.SysLog("using MySQL as database") | ||||
| 	common.UsingMySQL = true | ||||
| 	return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||
| 		PrepareStmt: true, // precompile SQL | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| 		if !config.IsMasterNode { | ||||
| 			return db, err | ||||
| 		} | ||||
| 		if common.UsingMySQL { | ||||
| 			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||
| 		} | ||||
| 		logger.SysLog("database migration started") | ||||
| 		err = db.AutoMigrate(&Channel{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Token{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&User{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Option{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Redemption{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Ability{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = db.AutoMigrate(&Log{}) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		logger.SysLog("database migrated") | ||||
| 		return db, err | ||||
| 	} else { | ||||
| 		logger.FatalLog(err) | ||||
| func openSQLite() (*gorm.DB, error) { | ||||
| 	logger.SysLog("SQL_DSN not set, using SQLite as database") | ||||
| 	common.UsingSQLite = true | ||||
| 	dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) | ||||
| 	return gorm.Open(sqlite.Open(dsn), &gorm.Config{ | ||||
| 		PrepareStmt: true, // precompile SQL | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func InitDB() { | ||||
| 	var err error | ||||
| 	DB, err = chooseDB("SQL_DSN") | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to initialize database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	return db, err | ||||
|  | ||||
| 	sqlDB := setDBConns(DB) | ||||
|  | ||||
| 	if !config.IsMasterNode { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if common.UsingMySQL { | ||||
| 		_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||
| 	} | ||||
|  | ||||
| 	logger.SysLog("database migration started") | ||||
| 	if err = migrateDB(); err != nil { | ||||
| 		logger.FatalLog("failed to migrate database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	logger.SysLog("database migrated") | ||||
| } | ||||
|  | ||||
| func migrateDB() error { | ||||
| 	var err error | ||||
| 	if err = DB.AutoMigrate(&Channel{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Token{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&User{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Option{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Redemption{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Ability{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Log{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err = DB.AutoMigrate(&Channel{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InitLogDB() { | ||||
| 	if os.Getenv("LOG_SQL_DSN") == "" { | ||||
| 		LOG_DB = DB | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	logger.SysLog("using secondary database for table logs") | ||||
| 	var err error | ||||
| 	LOG_DB, err = chooseDB("LOG_SQL_DSN") | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to initialize secondary database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	setDBConns(LOG_DB) | ||||
|  | ||||
| 	if !config.IsMasterNode { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	logger.SysLog("secondary database migration started") | ||||
| 	err = migrateLOGDB() | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to migrate secondary database: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	logger.SysLog("secondary database migrated") | ||||
| } | ||||
|  | ||||
| func migrateLOGDB() error { | ||||
| 	var err error | ||||
| 	if err = LOG_DB.AutoMigrate(&Log{}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func setDBConns(db *gorm.DB) *sql.DB { | ||||
| 	if config.DebugSQLEnabled { | ||||
| 		db = db.Debug() | ||||
| 	} | ||||
|  | ||||
| 	sqlDB, err := db.DB() | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("failed to connect database: " + err.Error()) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 	sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 	sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||
| 	return sqlDB | ||||
| } | ||||
|  | ||||
| func closeDB(db *gorm.DB) error { | ||||
|   | ||||
| @@ -28,6 +28,7 @@ func InitOptionMap() { | ||||
| 	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) | ||||
| 	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) | ||||
| 	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) | ||||
| 	config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) | ||||
| 	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) | ||||
| 	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) | ||||
| 	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) | ||||
| @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 			config.EmailVerificationEnabled = boolValue | ||||
| 		case "GitHubOAuthEnabled": | ||||
| 			config.GitHubOAuthEnabled = boolValue | ||||
| 		case "OidcEnabled": | ||||
| 			config.OidcEnabled = boolValue | ||||
| 		case "WeChatAuthEnabled": | ||||
| 			config.WeChatAuthEnabled = boolValue | ||||
| 		case "TurnstileCheckEnabled": | ||||
| @@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		config.LarkClientId = value | ||||
| 	case "LarkClientSecret": | ||||
| 		config.LarkClientSecret = value | ||||
| 	case "OidcClientId": | ||||
| 		config.OidcClientId = value | ||||
| 	case "OidcClientSecret": | ||||
| 		config.OidcClientSecret = value | ||||
| 	case "OidcWellKnown": | ||||
| 		config.OidcWellKnown = value | ||||
| 	case "OidcAuthorizationEndpoint": | ||||
| 		config.OidcAuthorizationEndpoint = value | ||||
| 	case "OidcTokenEndpoint": | ||||
| 		config.OidcTokenEndpoint = value | ||||
| 	case "OidcUserinfoEndpoint": | ||||
| 		config.OidcUserinfoEndpoint = value | ||||
| 	case "Footer": | ||||
| 		config.Footer = value | ||||
| 	case "SystemName": | ||||
|   | ||||
| @@ -1,11 +1,14 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -48,7 +51,7 @@ func GetRedemptionById(id int) (*Redemption, error) { | ||||
| 	return &redemption, err | ||||
| } | ||||
|  | ||||
| func Redeem(key string, userId int) (quota int64, err error) { | ||||
| func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) { | ||||
| 	if key == "" { | ||||
| 		return 0, errors.New("未提供兑换码") | ||||
| 	} | ||||
| @@ -82,7 +85,7 @@ func Redeem(key string, userId int) (quota int64, err error) { | ||||
| 	if err != nil { | ||||
| 		return 0, errors.New("兑换失败," + err.Error()) | ||||
| 	} | ||||
| 	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) | ||||
| 	RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) | ||||
| 	return redemption.Quota, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -30,7 +30,7 @@ type Token struct { | ||||
| 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Models         *string `json:"models" gorm:"default:''"`           // allowed models | ||||
| 	Models         *string `json:"models" gorm:"type:text"`            // allowed models | ||||
| 	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet | ||||
| } | ||||
|  | ||||
| @@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) { | ||||
| 	return &token, err | ||||
| } | ||||
|  | ||||
| func (token *Token) Insert() error { | ||||
| func (t *Token) Insert() error { | ||||
| 	var err error | ||||
| 	err = DB.Create(token).Error | ||||
| 	err = DB.Create(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Update Make sure your token's fields is completed, because this will update non-zero values | ||||
| func (token *Token) Update() error { | ||||
| func (t *Token) Update() error { | ||||
| 	var err error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error | ||||
| 	err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (token *Token) SelectUpdate() error { | ||||
| func (t *Token) SelectUpdate() error { | ||||
| 	// This can update zero values | ||||
| 	return DB.Model(token).Select("accessed_time", "status").Updates(token).Error | ||||
| 	return DB.Model(t).Select("accessed_time", "status").Updates(t).Error | ||||
| } | ||||
|  | ||||
| func (token *Token) Delete() error { | ||||
| func (t *Token) Delete() error { | ||||
| 	var err error | ||||
| 	err = DB.Delete(token).Error | ||||
| 	err = DB.Delete(t).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (t *Token) GetModels() string { | ||||
| 	if t == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	if t.Models == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return *t.Models | ||||
| } | ||||
|  | ||||
| func DeleteTokenById(id int, userId int) (err error) { | ||||
| 	// Why we need userId here? In case user want to delete other's token. | ||||
| 	if id == 0 || userId == 0 { | ||||
| @@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
|  | ||||
| func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
| 	token, err := GetTokenById(tokenId) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if quota > 0 { | ||||
| 		err = DecreaseUserQuota(token.UserId, quota) | ||||
| 	} else { | ||||
| 		err = IncreaseUserQuota(token.UserId, -quota) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !token.UnlimitedQuota { | ||||
| 		if quota > 0 { | ||||
| 			err = DecreaseTokenQuota(tokenId, quota) | ||||
|   | ||||
| @@ -1,16 +1,19 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/blacklist" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"gorm.io/gorm" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -39,6 +42,7 @@ type User struct { | ||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||
| 	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"` | ||||
| 	OidcId           string `json:"oidc_id" gorm:"column:oidc_id;index"` | ||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||
| 	Quota            int64  `json:"quota" gorm:"bigint;default:0"` | ||||
| @@ -91,7 +95,7 @@ func GetUserById(id int, selectAll bool) (*User, error) { | ||||
| 	if selectAll { | ||||
| 		err = DB.First(&user, "id = ?", id).Error | ||||
| 	} else { | ||||
| 		err = DB.Omit("password").First(&user, "id = ?", id).Error | ||||
| 		err = DB.Omit("password", "access_token").First(&user, "id = ?", id).Error | ||||
| 	} | ||||
| 	return &user, err | ||||
| } | ||||
| @@ -113,7 +117,7 @@ func DeleteUserById(id int) (err error) { | ||||
| 	return user.Delete() | ||||
| } | ||||
|  | ||||
| func (user *User) Insert(inviterId int) error { | ||||
| func (user *User) Insert(ctx context.Context, inviterId int) error { | ||||
| 	var err error | ||||
| 	if user.Password != "" { | ||||
| 		user.Password, err = common.Password2Hash(user.Password) | ||||
| @@ -129,16 +133,16 @@ func (user *User) Insert(inviterId int) error { | ||||
| 		return result.Error | ||||
| 	} | ||||
| 	if config.QuotaForNewUser > 0 { | ||||
| 		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) | ||||
| 		RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) | ||||
| 	} | ||||
| 	if inviterId != 0 { | ||||
| 		if config.QuotaForInvitee > 0 { | ||||
| 			_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) | ||||
| 			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) | ||||
| 			RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) | ||||
| 		} | ||||
| 		if config.QuotaForInviter > 0 { | ||||
| 			_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) | ||||
| 			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) | ||||
| 			RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) | ||||
| 		} | ||||
| 	} | ||||
| 	// create default token | ||||
| @@ -245,6 +249,14 @@ func (user *User) FillUserByLarkId() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByOidcId() error { | ||||
| 	if user.OidcId == "" { | ||||
| 		return errors.New("oidc id 为空!") | ||||
| 	} | ||||
| 	DB.Where(User{OidcId: user.OidcId}).First(user) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByWeChatId() error { | ||||
| 	if user.WeChatId == "" { | ||||
| 		return errors.New("WeChat id 为空!") | ||||
| @@ -277,6 +289,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsOidcIdAlreadyTaken(oidcId string) bool { | ||||
| 	return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsUsernameAlreadyTaken(username string) bool { | ||||
| 	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|   | ||||
| @@ -1,10 +1,11 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||
| @@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { | ||||
| 		return true | ||||
| 	} | ||||
| 	switch err.Type { | ||||
| 	case "insufficient_quota": | ||||
| 		return true | ||||
| 	// https://docs.anthropic.com/claude/reference/errors | ||||
| 	case "authentication_error": | ||||
| 		return true | ||||
| 	case "permission_error": | ||||
| 		return true | ||||
| 	case "forbidden": | ||||
| 	case "insufficient_quota", "authentication_error", "permission_error", "forbidden": | ||||
| 		return true | ||||
| 	} | ||||
| 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic | ||||
| 		return true | ||||
| 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||
| 		return true | ||||
| 	} | ||||
| 	//if strings.Contains(err.Message, "quota") { | ||||
| 	//	return true | ||||
| 	//} | ||||
| 	if strings.Contains(err.Message, "credit") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.Contains(err.Message, "balance") { | ||||
|  | ||||
| 	lowerMessage := strings.ToLower(err.Message) | ||||
| 	if strings.Contains(lowerMessage, "your access was terminated") || | ||||
| 		strings.Contains(lowerMessage, "violation of our policies") || | ||||
| 		strings.Contains(lowerMessage, "your credit balance is too low") || | ||||
| 		strings.Contains(lowerMessage, "organization has been disabled") || | ||||
| 		strings.Contains(lowerMessage, "credit") || | ||||
| 		strings.Contains(lowerMessage, "balance") || | ||||
| 		strings.Contains(lowerMessage, "permission denied") || | ||||
| 		strings.Contains(lowerMessage, "organization has been restricted") || // groq | ||||
| 		strings.Contains(lowerMessage, "已欠费") { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
|   | ||||
| @@ -15,7 +15,10 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/ollama" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/palm" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/proxy" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/replicate" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/tencent" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/zhipu" | ||||
| 	"github.com/songquanpeng/one-api/relay/apitype" | ||||
| @@ -55,6 +58,12 @@ func GetAdaptor(apiType int) adaptor.Adaptor { | ||||
| 		return &cloudflare.Adaptor{} | ||||
| 	case apitype.DeepL: | ||||
| 		return &deepl.Adaptor{} | ||||
| 	case apitype.VertexAI: | ||||
| 		return &vertexai.Adaptor{} | ||||
| 	case apitype.Proxy: | ||||
| 		return &proxy.Adaptor{} | ||||
| 	case apitype.Replicate: | ||||
| 		return &replicate.Adaptor{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,23 @@ | ||||
| package ali | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", | ||||
| 	"text-embedding-v1", | ||||
| 	"qwen-turbo", "qwen-turbo-latest", | ||||
| 	"qwen-plus", "qwen-plus-latest", | ||||
| 	"qwen-max", "qwen-max-latest", | ||||
| 	"qwen-max-longcontext", | ||||
| 	"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest", | ||||
| 	"qwen-vl-ocr", "qwen-vl-ocr-latest", | ||||
| 	"qwen-audio-turbo", | ||||
| 	"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest", | ||||
| 	"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest", | ||||
| 	"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct", | ||||
| 	"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct", | ||||
| 	"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat", | ||||
| 	"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat", | ||||
| 	"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1", | ||||
| 	"qwen2-audio-instruct", "qwen-audio-chat", | ||||
| 	"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct", | ||||
| 	"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct", | ||||
| 	"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1", | ||||
| 	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", | ||||
| } | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package ali | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| @@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		enableSearch = true | ||||
| 		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) | ||||
| 	} | ||||
| 	if request.TopP >= 1 { | ||||
| 		request.TopP = 0.9999 | ||||
| 	} | ||||
| 	request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) | ||||
| 	return &ChatRequest{ | ||||
| 		Model: aliModel, | ||||
| 		Input: Input{ | ||||
| @@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
|  | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model: "text-embedding-v1", | ||||
| 		Model: request.Model, | ||||
| 		Input: struct { | ||||
| 			Texts []string `json:"texts"` | ||||
| 		}{ | ||||
| @@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
|  | ||||
| 	requestModel := c.GetString(ctxkey.RequestModel) | ||||
| 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||
| 	fullTextResponse.Model = requestModel | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
|   | ||||
| @@ -16,13 +16,13 @@ type Input struct { | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	TopP              float64      `json:"top_p,omitempty"` | ||||
| 	TopP              *float64     `json:"top_p,omitempty"` | ||||
| 	TopK              int          `json:"top_k,omitempty"` | ||||
| 	Seed              uint64       `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool         `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool         `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int          `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64      `json:"temperature,omitempty"` | ||||
| 	Temperature       *float64     `json:"temperature,omitempty"` | ||||
| 	ResultFormat      string       `json:"result_format,omitempty"` | ||||
| 	Tools             []model.Tool `json:"tools,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -3,12 +3,14 @@ package anthropic | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| @@ -31,6 +33,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me | ||||
| 	} | ||||
| 	req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 	req.Header.Set("anthropic-beta", "messages-2023-12-15") | ||||
|  | ||||
| 	// https://x.com/alexalbert__/status/1812921642143900036 | ||||
| 	// claude-3-5-sonnet can support 8k context | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") { | ||||
| 		req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -3,7 +3,10 @@ package anthropic | ||||
| var ModelList = []string{ | ||||
| 	"claude-instant-1.2", "claude-2.0", "claude-2.1", | ||||
| 	"claude-3-haiku-20240307", | ||||
| 	"claude-3-5-haiku-20241022", | ||||
| 	"claude-3-sonnet-20240229", | ||||
| 	"claude-3-opus-20240229", | ||||
| 	"claude-3-5-sonnet-20240620", | ||||
| 	"claude-3-5-sonnet-20241022", | ||||
| 	"claude-3-5-sonnet-latest", | ||||
| } | ||||
|   | ||||
| @@ -48,8 +48,8 @@ type Request struct { | ||||
| 	MaxTokens     int       `json:"max_tokens,omitempty"` | ||||
| 	StopSequences []string  `json:"stop_sequences,omitempty"` | ||||
| 	Stream        bool      `json:"stream,omitempty"` | ||||
| 	Temperature   float64   `json:"temperature,omitempty"` | ||||
| 	TopP          float64   `json:"top_p,omitempty"` | ||||
| 	Temperature   *float64  `json:"temperature,omitempty"` | ||||
| 	TopP          *float64  `json:"top_p,omitempty"` | ||||
| 	TopK          int       `json:"top_k,omitempty"` | ||||
| 	Tools         []Tool    `json:"tools,omitempty"` | ||||
| 	ToolChoice    any       `json:"tool_choice,omitempty"` | ||||
|   | ||||
| @@ -1,17 +1,16 @@ | ||||
| package aws | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
| @@ -19,18 +18,52 @@ import ( | ||||
| var _ adaptor.Adaptor = new(Adaptor) | ||||
| 
 | ||||
| type Adaptor struct { | ||||
| 	meta      *meta.Meta | ||||
| 	awsClient *bedrockruntime.Client | ||||
| 	awsAdapter utils.AwsAdapter | ||||
| 
 | ||||
| 	Meta      *meta.Meta | ||||
| 	AwsClient *bedrockruntime.Client | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| 	a.awsClient = bedrockruntime.New(bedrockruntime.Options{ | ||||
| 	a.Meta = meta | ||||
| 	a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ | ||||
| 		Region:      meta.Config.Region, | ||||
| 		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 
 | ||||
| 	adaptor := GetAdaptor(request.Model) | ||||
| 	if adaptor == nil { | ||||
| 		return nil, errors.New("adaptor not found") | ||||
| 	} | ||||
| 
 | ||||
| 	a.awsAdapter = adaptor | ||||
| 	return adaptor.ConvertRequest(c, relayMode, request) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if a.awsAdapter == nil { | ||||
| 		return nil, utils.WrapErr(errors.New("awsAdapter is nil")) | ||||
| 	} | ||||
| 	return a.awsAdapter.DoResponse(c, a.AwsClient, meta) | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetModelList() (models []string) { | ||||
| 	for model := range adaptors { | ||||
| 		models = append(models, model) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "aws" | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return "", nil | ||||
| } | ||||
| @@ -39,17 +72,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 
 | ||||
| 	claudeReq := anthropic.ConvertRequest(*request) | ||||
| 	c.Set(ctxkey.RequestModel, request.Model) | ||||
| 	c.Set(ctxkey.ConvertedRequest, claudeReq) | ||||
| 	return claudeReq, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| @@ -60,23 +82,3 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, a.awsClient) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, a.awsClient, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetModelList() (models []string) { | ||||
| 	for n := range awsModelIDMap { | ||||
| 		models = append(models, n) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "aws" | ||||
| } | ||||
							
								
								
									
										37
									
								
								relay/adaptor/aws/claude/adapter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								relay/adaptor/aws/claude/adapter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| var _ utils.AwsAdapter = new(Adaptor) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
|  | ||||
| 	claudeReq := anthropic.ConvertRequest(*request) | ||||
| 	c.Set(ctxkey.RequestModel, request.Model) | ||||
| 	c.Set(ctxkey.ConvertedRequest, claudeReq) | ||||
| 	return claudeReq, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, awsCli) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, awsCli, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| @@ -5,8 +5,6 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 
 | ||||
| @@ -17,34 +15,31 @@ import ( | ||||
| 	"github.com/jinzhu/copier" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
| 
 | ||||
| func wrapErr(err error) *relaymodel.ErrorWithStatusCode { | ||||
| 	return &relaymodel.ErrorWithStatusCode{ | ||||
| 		StatusCode: http.StatusInternalServerError, | ||||
| 		Error: relaymodel.Error{ | ||||
| 			Message: fmt.Sprintf("%s", err.Error()), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||||
| var awsModelIDMap = map[string]string{ | ||||
| var AwsModelIDMap = map[string]string{ | ||||
| 	"claude-instant-1.2":         "anthropic.claude-instant-v1", | ||||
| 	"claude-2.0":                 "anthropic.claude-v2", | ||||
| 	"claude-2.1":                 "anthropic.claude-v2:1", | ||||
| 	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||||
| 	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0", | ||||
| 	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0", | ||||
| 	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0", | ||||
| 	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||||
| 	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||||
| 	"claude-3-5-sonnet-latest":   "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||||
| 	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0", | ||||
| } | ||||
| 
 | ||||
| func awsModelID(requestModel string) (string, error) { | ||||
| 	if awsModelID, ok := awsModelIDMap[requestModel]; ok { | ||||
| 	if awsModelID, ok := AwsModelIDMap[requestModel]; ok { | ||||
| 		return awsModelID, nil | ||||
| 	} | ||||
| 
 | ||||
| @@ -54,7 +49,7 @@ func awsModelID(requestModel string) (string, error) { | ||||
| func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsReq := &bedrockruntime.InvokeModelInput{ | ||||
| @@ -65,30 +60,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (* | ||||
| 
 | ||||
| 	claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return wrapErr(errors.New("request not found")), nil | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
| 	claudeReq := claudeReq_.(*anthropic.Request) | ||||
| 	awsClaudeReq := &Request{ | ||||
| 		AnthropicVersion: "bedrock-2023-05-31", | ||||
| 	} | ||||
| 	if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsReq.Body, err = json.Marshal(awsClaudeReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "InvokeModel")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	claudeResponse := new(anthropic.Response) | ||||
| 	err = json.Unmarshal(awsResp.Body, claudeResponse) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "unmarshal response")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) | ||||
| @@ -108,7 +103,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ | ||||
| @@ -119,7 +114,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 
 | ||||
| 	claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return wrapErr(errors.New("request not found")), nil | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
| 	claudeReq := claudeReq_.(*anthropic.Request) | ||||
| 
 | ||||
| @@ -127,16 +122,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E | ||||
| 		AnthropicVersion: "bedrock-2023-05-31", | ||||
| 	} | ||||
| 	if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "copy request")), nil | ||||
| 	} | ||||
| 	awsReq.Body, err = json.Marshal(awsClaudeReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
| 
 | ||||
| 	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil | ||||
| 	} | ||||
| 	stream := awsResp.GetStream() | ||||
| 	defer stream.Close() | ||||
| @@ -11,8 +11,8 @@ type Request struct { | ||||
| 	Messages         []anthropic.Message `json:"messages"` | ||||
| 	System           string              `json:"system,omitempty"` | ||||
| 	MaxTokens        int                 `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64             `json:"temperature,omitempty"` | ||||
| 	TopP             float64             `json:"top_p,omitempty"` | ||||
| 	Temperature      *float64            `json:"temperature,omitempty"` | ||||
| 	TopP             *float64            `json:"top_p,omitempty"` | ||||
| 	TopK             int                 `json:"top_k,omitempty"` | ||||
| 	StopSequences    []string            `json:"stop_sequences,omitempty"` | ||||
| 	Tools            []anthropic.Tool    `json:"tools,omitempty"` | ||||
							
								
								
									
										37
									
								
								relay/adaptor/aws/llama3/adapter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								relay/adaptor/aws/llama3/adapter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| var _ utils.AwsAdapter = new(Adaptor) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
|  | ||||
| 	llamaReq := ConvertRequest(*request) | ||||
| 	c.Set(ctxkey.RequestModel, request.Model) | ||||
| 	c.Set(ctxkey.ConvertedRequest, llamaReq) | ||||
| 	return llamaReq, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, awsCli) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, awsCli, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										231
									
								
								relay/adaptor/aws/llama3/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										231
									
								
								relay/adaptor/aws/llama3/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,231 @@ | ||||
| // Package aws provides the AWS adaptor for the relay service. | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"text/template" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| // Only support llama-3-8b and llama-3-70b instruction models | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html | ||||
| var AwsModelIDMap = map[string]string{ | ||||
| 	"llama3-8b-8192":  "meta.llama3-8b-instruct-v1:0", | ||||
| 	"llama3-70b-8192": "meta.llama3-70b-instruct-v1:0", | ||||
| } | ||||
|  | ||||
| func awsModelID(requestModel string) (string, error) { | ||||
| 	if awsModelID, ok := AwsModelIDMap[requestModel]; ok { | ||||
| 		return awsModelID, nil | ||||
| 	} | ||||
|  | ||||
| 	return "", errors.Errorf("model %s not found", requestModel) | ||||
| } | ||||
|  | ||||
| // promptTemplate with range | ||||
| const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|> | ||||
| ` | ||||
|  | ||||
| var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate)) | ||||
|  | ||||
| func RenderPrompt(messages []relaymodel.Message) string { | ||||
| 	var buf bytes.Buffer | ||||
| 	err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages}) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error rendering prompt messages: " + err.Error()) | ||||
| 	} | ||||
| 	return buf.String() | ||||
| } | ||||
|  | ||||
| func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { | ||||
| 	llamaRequest := Request{ | ||||
| 		MaxGenLen:   textRequest.MaxTokens, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 		TopP:        textRequest.TopP, | ||||
| 	} | ||||
| 	if llamaRequest.MaxGenLen == 0 { | ||||
| 		llamaRequest.MaxGenLen = 2048 | ||||
| 	} | ||||
| 	prompt := RenderPrompt(textRequest.Messages) | ||||
| 	llamaRequest.Prompt = prompt | ||||
| 	return &llamaRequest | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq := &bedrockruntime.InvokeModelInput{ | ||||
| 		ModelId:     aws.String(awsModelId), | ||||
| 		Accept:      aws.String("application/json"), | ||||
| 		ContentType: aws.String("application/json"), | ||||
| 	} | ||||
|  | ||||
| 	llamaReq, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq.Body, err = json.Marshal(llamaReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil | ||||
| 	} | ||||
|  | ||||
| 	var llamaResponse Response | ||||
| 	err = json.Unmarshal(awsResp.Body, &llamaResponse) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil | ||||
| 	} | ||||
|  | ||||
| 	openaiResp := ResponseLlama2OpenAI(&llamaResponse) | ||||
| 	openaiResp.Model = modelName | ||||
| 	usage := relaymodel.Usage{ | ||||
| 		PromptTokens:     llamaResponse.PromptTokenCount, | ||||
| 		CompletionTokens: llamaResponse.GenerationTokenCount, | ||||
| 		TotalTokens:      llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount, | ||||
| 	} | ||||
| 	openaiResp.Usage = usage | ||||
|  | ||||
| 	c.JSON(http.StatusOK, openaiResp) | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse { | ||||
| 	var responseText string | ||||
| 	if len(llamaResponse.Generation) > 0 { | ||||
| 		responseText = llamaResponse.Generation | ||||
| 	} | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: relaymodel.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: responseText, | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: llamaResponse.StopReason, | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ | ||||
| 		ModelId:     aws.String(awsModelId), | ||||
| 		Accept:      aws.String("application/json"), | ||||
| 		ContentType: aws.String("application/json"), | ||||
| 	} | ||||
|  | ||||
| 	llamaReq, ok := c.Get(ctxkey.ConvertedRequest) | ||||
| 	if !ok { | ||||
| 		return utils.WrapErr(errors.New("request not found")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsReq.Body, err = json.Marshal(llamaReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) | ||||
| 	if err != nil { | ||||
| 		return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil | ||||
| 	} | ||||
| 	stream := awsResp.GetStream() | ||||
| 	defer stream.Close() | ||||
|  | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	var usage relaymodel.Usage | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		event, ok := <-stream.Events() | ||||
| 		if !ok { | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
|  | ||||
| 		switch v := event.(type) { | ||||
| 		case *types.ResponseStreamMemberChunk: | ||||
| 			var llamaResp StreamResponse | ||||
| 			err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return false | ||||
| 			} | ||||
|  | ||||
| 			if llamaResp.PromptTokenCount > 0 { | ||||
| 				usage.PromptTokens = llamaResp.PromptTokenCount | ||||
| 			} | ||||
| 			if llamaResp.StopReason == "stop" { | ||||
| 				usage.CompletionTokens = llamaResp.GenerationTokenCount | ||||
| 				usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||
| 			} | ||||
| 			response := StreamResponseLlama2OpenAI(&llamaResp) | ||||
| 			response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID()) | ||||
| 			response.Model = c.GetString(ctxkey.OriginalModel) | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case *types.UnknownUnionMember: | ||||
| 			fmt.Println("unknown tag:", v.Tag) | ||||
| 			return false | ||||
| 		default: | ||||
| 			fmt.Println("union is nil or unknown type") | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = llamaResponse.Generation | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	finishReason := llamaResponse.StopReason | ||||
| 	if finishReason != "null" { | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	var openaiResponse openai.ChatCompletionsStreamResponse | ||||
| 	openaiResponse.Object = "chat.completion.chunk" | ||||
| 	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &openaiResponse | ||||
| } | ||||
							
								
								
									
										45
									
								
								relay/adaptor/aws/llama3/main_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								relay/adaptor/aws/llama3/main_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| package aws_test | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestRenderPrompt(t *testing.T) { | ||||
| 	messages := []relaymodel.Message{ | ||||
| 		{ | ||||
| 			Role:    "user", | ||||
| 			Content: "What's your name?", | ||||
| 		}, | ||||
| 	} | ||||
| 	prompt := aws.RenderPrompt(messages) | ||||
| 	expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||||
| ` | ||||
| 	assert.Equal(t, expected, prompt) | ||||
|  | ||||
| 	messages = []relaymodel.Message{ | ||||
| 		{ | ||||
| 			Role:    "system", | ||||
| 			Content: "Your name is Kat. You are a detective.", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Role:    "user", | ||||
| 			Content: "What's your name?", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: "Kat", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Role:    "user", | ||||
| 			Content: "What's your job?", | ||||
| 		}, | ||||
| 	} | ||||
| 	prompt = aws.RenderPrompt(messages) | ||||
| 	expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||||
| ` | ||||
| 	assert.Equal(t, expected, prompt) | ||||
| } | ||||
							
								
								
									
										29
									
								
								relay/adaptor/aws/llama3/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								relay/adaptor/aws/llama3/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| package aws | ||||
|  | ||||
| // Request is the request to AWS Llama3 | ||||
| // | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html | ||||
| type Request struct { | ||||
| 	Prompt      string   `json:"prompt"` | ||||
| 	MaxGenLen   int      `json:"max_gen_len,omitempty"` | ||||
| 	Temperature *float64 `json:"temperature,omitempty"` | ||||
| 	TopP        *float64 `json:"top_p,omitempty"` | ||||
| } | ||||
|  | ||||
| // Response is the response from AWS Llama3 | ||||
| // | ||||
| // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html | ||||
| type Response struct { | ||||
| 	Generation           string `json:"generation"` | ||||
| 	PromptTokenCount     int    `json:"prompt_token_count"` | ||||
| 	GenerationTokenCount int    `json:"generation_token_count"` | ||||
| 	StopReason           string `json:"stop_reason"` | ||||
| } | ||||
|  | ||||
| // {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None} | ||||
| type StreamResponse struct { | ||||
| 	Generation           string `json:"generation"` | ||||
| 	PromptTokenCount     int    `json:"prompt_token_count"` | ||||
| 	GenerationTokenCount int    `json:"generation_token_count"` | ||||
| 	StopReason           string `json:"stop_reason"` | ||||
| } | ||||
							
								
								
									
										39
									
								
								relay/adaptor/aws/registry.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								relay/adaptor/aws/registry.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude" | ||||
| 	llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" | ||||
| ) | ||||
|  | ||||
| type AwsModelType int | ||||
|  | ||||
| const ( | ||||
| 	AwsClaude AwsModelType = iota + 1 | ||||
| 	AwsLlama3 | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	adaptors = map[string]AwsModelType{} | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	for model := range claude.AwsModelIDMap { | ||||
| 		adaptors[model] = AwsClaude | ||||
| 	} | ||||
| 	for model := range llama3.AwsModelIDMap { | ||||
| 		adaptors[model] = AwsLlama3 | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetAdaptor(model string) utils.AwsAdapter { | ||||
| 	adaptorType := adaptors[model] | ||||
| 	switch adaptorType { | ||||
| 	case AwsClaude: | ||||
| 		return &claude.Adaptor{} | ||||
| 	case AwsLlama3: | ||||
| 		return &llama3.Adaptor{} | ||||
| 	default: | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										51
									
								
								relay/adaptor/aws/utils/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								relay/adaptor/aws/utils/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | ||||
| package utils | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type AwsAdapter interface { | ||||
| 	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) | ||||
| 	DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) | ||||
| } | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	Meta      *meta.Meta | ||||
| 	AwsClient *bedrockruntime.Client | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.Meta = meta | ||||
| 	a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ | ||||
| 		Region:      meta.Config.Region, | ||||
| 		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return "", nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
							
								
								
									
										16
									
								
								relay/adaptor/aws/utils/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								relay/adaptor/aws/utils/utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| package utils | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func WrapErr(err error) *relaymodel.ErrorWithStatusCode { | ||||
| 	return &relaymodel.ErrorWithStatusCode{ | ||||
| 		StatusCode: http.StatusInternalServerError, | ||||
| 		Error: relaymodel.Error{ | ||||
| 			Message: err.Error(), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| @@ -35,9 +35,9 @@ type Message struct { | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Messages        []Message `json:"messages"` | ||||
| 	Temperature     float64   `json:"temperature,omitempty"` | ||||
| 	TopP            float64   `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    float64   `json:"penalty_score,omitempty"` | ||||
| 	Temperature     *float64  `json:"temperature,omitempty"` | ||||
| 	TopP            *float64  `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    *float64  `json:"penalty_score,omitempty"` | ||||
| 	Stream          bool      `json:"stream,omitempty"` | ||||
| 	System          string    `json:"system,omitempty"` | ||||
| 	DisableSearch   bool      `json:"disable_search,omitempty"` | ||||
|   | ||||
| @@ -5,11 +5,13 @@ import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| @@ -27,8 +29,33 @@ func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| // WorkerAI cannot be used across accounts with AIGateWay | ||||
| // https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints | ||||
| // https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai | ||||
| func (a *Adaptor) isAIGateWay(baseURL string) bool { | ||||
| 	return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai") | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil | ||||
| 	isAIGateWay := a.isAIGateWay(meta.BaseURL) | ||||
| 	var urlPrefix string | ||||
| 	if isAIGateWay { | ||||
| 		urlPrefix = meta.BaseURL | ||||
| 	} else { | ||||
| 		urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID) | ||||
| 	} | ||||
|  | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.ChatCompletions: | ||||
| 		return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil | ||||
| 	case relaymode.Embeddings: | ||||
| 		return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil | ||||
| 	default: | ||||
| 		if isAIGateWay { | ||||
| 			return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil | ||||
| 		} | ||||
| 		return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| @@ -41,7 +68,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return ConvertRequest(*request), nil | ||||
| 	switch relayMode { | ||||
| 	case relaymode.Completions: | ||||
| 		return ConvertCompletionsRequest(*request), nil | ||||
| 	case relaymode.ChatCompletions, relaymode.Embeddings: | ||||
| 		return request, nil | ||||
| 	default: | ||||
| 		return nil, errors.New("not implemented") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package cloudflare | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"@cf/meta/llama-3.1-8b-instruct", | ||||
| 	"@cf/meta/llama-2-7b-chat-fp16", | ||||
| 	"@cf/meta/llama-2-7b-chat-int8", | ||||
| 	"@cf/mistral/mistral-7b-instruct-v0.1", | ||||
|   | ||||
| @@ -3,11 +3,13 @@ package cloudflare | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -16,57 +18,23 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	var promptBuilder strings.Builder | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		promptBuilder.WriteString(message.StringContent()) | ||||
| 		promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 | ||||
| 	} | ||||
|  | ||||
| func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	p, _ := textRequest.Prompt.(string) | ||||
| 	return &Request{ | ||||
| 		Prompt:      p, | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| 		Prompt:      promptBuilder.String(), | ||||
| 		Stream:      textRequest.Stream, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: cloudflareResponse.Result.Response, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = cloudflareResponse.Response | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	openaiResponse := openai.ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 	} | ||||
| 	return &openaiResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	id := helper.GetResponseID(c) | ||||
| 	responseModel := c.GetString("original_model") | ||||
| 	responseModel := c.GetString(ctxkey.OriginalModel) | ||||
| 	var responseText string | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| @@ -77,22 +45,22 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN | ||||
| 		data = strings.TrimPrefix(data, "data: ") | ||||
| 		data = strings.TrimSuffix(data, "\r") | ||||
|  | ||||
| 		var cloudflareResponse StreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &cloudflareResponse) | ||||
| 		if data == "[DONE]" { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		var response openai.ChatCompletionsStreamResponse | ||||
| 		err := json.Unmarshal([]byte(data), &response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 		if response == nil { | ||||
| 			continue | ||||
| 		for _, v := range response.Choices { | ||||
| 			v.Delta.Role = "assistant" | ||||
| 			responseText += v.Delta.StringContent() | ||||
| 		} | ||||
|  | ||||
| 		responseText += cloudflareResponse.Response | ||||
| 		response.Id = id | ||||
| 		response.Model = responseModel | ||||
|  | ||||
| 		response.Model = modelName | ||||
| 		err = render.ObjectData(c, response) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(err.Error()) | ||||
| @@ -123,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var cloudflareResponse Response | ||||
| 	err = json.Unmarshal(responseBody, &cloudflareResponse) | ||||
| 	var response openai.TextResponse | ||||
| 	err = json.Unmarshal(responseBody, &response) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 	fullTextResponse.Model = modelName | ||||
| 	usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) | ||||
| 	fullTextResponse.Usage = *usage | ||||
| 	fullTextResponse.Id = helper.GetResponseID(c) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	response.Model = modelName | ||||
| 	var responseText string | ||||
| 	for _, v := range response.Choices { | ||||
| 		responseText += v.Message.Content.(string) | ||||
| 	} | ||||
| 	usage := openai.ResponseText2Usage(responseText, modelName, promptTokens) | ||||
| 	response.Usage = *usage | ||||
| 	response.Id = helper.GetResponseID(c) | ||||
| 	jsonResponse, err := json.Marshal(response) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	_, _ = c.Writer.Write(jsonResponse) | ||||
| 	return nil, usage | ||||
| } | ||||
|   | ||||
| @@ -1,25 +1,13 @@ | ||||
| package cloudflare | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/relay/model" | ||||
|  | ||||
| type Request struct { | ||||
| 	Lora        string  `json:"lora,omitempty"` | ||||
| 	MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 	Prompt      string  `json:"prompt,omitempty"` | ||||
| 	Raw         bool    `json:"raw,omitempty"` | ||||
| 	Stream      bool    `json:"stream,omitempty"` | ||||
| 	Temperature float64 `json:"temperature,omitempty"` | ||||
| } | ||||
|  | ||||
| type Result struct { | ||||
| 	Response string `json:"response"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	Result   Result   `json:"result"` | ||||
| 	Success  bool     `json:"success"` | ||||
| 	Errors   []string `json:"errors"` | ||||
| 	Messages []string `json:"messages"` | ||||
| } | ||||
|  | ||||
| type StreamResponse struct { | ||||
| 	Response string `json:"response"` | ||||
| 	Messages    []model.Message `json:"messages,omitempty"` | ||||
| 	Lora        string          `json:"lora,omitempty"` | ||||
| 	MaxTokens   int             `json:"max_tokens,omitempty"` | ||||
| 	Prompt      string          `json:"prompt,omitempty"` | ||||
| 	Raw         bool            `json:"raw,omitempty"` | ||||
| 	Stream      bool            `json:"stream,omitempty"` | ||||
| 	Temperature *float64        `json:"temperature,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		K:                textRequest.TopK, | ||||
| 		Stream:           textRequest.Stream, | ||||
| 		FrequencyPenalty: textRequest.FrequencyPenalty, | ||||
| 		PresencePenalty:  textRequest.FrequencyPenalty, | ||||
| 		PresencePenalty:  textRequest.PresencePenalty, | ||||
| 		Seed:             int(textRequest.Seed), | ||||
| 	} | ||||
| 	if cohereRequest.Model == "" { | ||||
|   | ||||
| @@ -10,15 +10,15 @@ type Request struct { | ||||
| 	PromptTruncation string        `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" | ||||
| 	Connectors       []Connector   `json:"connectors,omitempty"` | ||||
| 	Documents        []Document    `json:"documents,omitempty"` | ||||
| 	Temperature      float64       `json:"temperature,omitempty"` // 默认值为0.3 | ||||
| 	Temperature      *float64      `json:"temperature,omitempty"` // 默认值为0.3 | ||||
| 	MaxTokens        int           `json:"max_tokens,omitempty"` | ||||
| 	MaxInputTokens   int           `json:"max_input_tokens,omitempty"` | ||||
| 	K                int           `json:"k,omitempty"` // 默认值为0 | ||||
| 	P                float64       `json:"p,omitempty"` // 默认值为0.75 | ||||
| 	P                *float64      `json:"p,omitempty"` // 默认值为0.75 | ||||
| 	Seed             int           `json:"seed,omitempty"` | ||||
| 	StopSequences    []string      `json:"stop_sequences,omitempty"` | ||||
| 	FrequencyPenalty float64       `json:"frequency_penalty,omitempty"` // 默认值为0.0 | ||||
| 	PresencePenalty  float64       `json:"presence_penalty,omitempty"`  // 默认值为0.0 | ||||
| 	FrequencyPenalty *float64      `json:"frequency_penalty,omitempty"` // 默认值为0.0 | ||||
| 	PresencePenalty  *float64      `json:"presence_penalty,omitempty"`  // 默认值为0.0 | ||||
| 	Tools            []Tool        `json:"tools,omitempty"` | ||||
| 	ToolResults      []ToolResult  `json:"tool_results,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -7,8 +7,12 @@ import ( | ||||
| ) | ||||
|  | ||||
| func GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if meta.Mode == relaymode.ChatCompletions { | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.ChatCompletions: | ||||
| 		return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil | ||||
| 	case relaymode.Embeddings: | ||||
| 		return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil | ||||
| 	default: | ||||
| 	} | ||||
| 	return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode) | ||||
| } | ||||
|   | ||||
| @@ -7,7 +7,6 @@ import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	channelhelper "github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| @@ -24,7 +23,15 @@ func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) | ||||
| 	var defaultVersion string | ||||
| 	switch meta.ActualModelName { | ||||
| 	case "gemini-2.0-flash-exp", | ||||
| 		"gemini-2.0-flash-thinking-exp", | ||||
| 		"gemini-2.0-flash-thinking-exp-01-21": | ||||
| 		defaultVersion = "v1beta" | ||||
| 	} | ||||
|  | ||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) | ||||
| 	action := "" | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.Embeddings: | ||||
| @@ -36,6 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if meta.IsStream { | ||||
| 		action = "streamGenerateContent?alt=sse" | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,9 @@ package gemini | ||||
| // https://ai.google.dev/models/gemini | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro", | ||||
| 	"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004", | ||||
| 	"gemini-pro", "gemini-1.0-pro", | ||||
| 	"gemini-1.5-flash", "gemini-1.5-pro", | ||||
| 	"text-embedding-004", "aqa", | ||||
| 	"gemini-2.0-flash-exp", | ||||
| 	"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", | ||||
| } | ||||
|   | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -28,6 +29,11 @@ const ( | ||||
| 	VisionMaxImageNum = 16 | ||||
| ) | ||||
|  | ||||
| var mimeTypeMap = map[string]string{ | ||||
| 	"json_object": "application/json", | ||||
| 	"text":        "text/plain", | ||||
| } | ||||
|  | ||||
| // Setting safety to the lowest possible values since Gemini is already powerless enough | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	geminiRequest := ChatRequest{ | ||||
| @@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 			{ | ||||
| 				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY", | ||||
| 				Threshold: config.GeminiSafetySetting, | ||||
| 			}, | ||||
| 		}, | ||||
| 		GenerationConfig: ChatGenerationConfig{ | ||||
| 			Temperature:     textRequest.Temperature, | ||||
| @@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			MaxOutputTokens: textRequest.MaxTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	if textRequest.ResponseFormat != nil { | ||||
| 		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok { | ||||
| 			geminiRequest.GenerationConfig.ResponseMimeType = mimeType | ||||
| 		} | ||||
| 		if textRequest.ResponseFormat.JsonSchema != nil { | ||||
| 			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema | ||||
| 			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"] | ||||
| 		} | ||||
| 	} | ||||
| 	if textRequest.Tools != nil { | ||||
| 		functions := make([]model.Function, 0, len(textRequest.Tools)) | ||||
| 		for _, tool := range textRequest.Tools { | ||||
| @@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 			if candidate.Content.Parts[0].FunctionCall != nil { | ||||
| 				choice.Message.ToolCalls = getToolCalls(&candidate) | ||||
| 			} else { | ||||
| 				choice.Message.Content = candidate.Content.Parts[0].Text | ||||
| 				var builder strings.Builder | ||||
| 				for _, part := range candidate.Content.Parts { | ||||
| 					if i > 0 { | ||||
| 						builder.WriteString("\n") | ||||
| 					} | ||||
| 					builder.WriteString(part.Text) | ||||
| 				} | ||||
| 				choice.Message.Content = builder.String() | ||||
| 			} | ||||
| 		} else { | ||||
| 			choice.Message.Content = "" | ||||
|   | ||||
| @@ -65,10 +65,12 @@ type ChatTools struct { | ||||
| } | ||||
|  | ||||
| type ChatGenerationConfig struct { | ||||
| 	Temperature     float64  `json:"temperature,omitempty"` | ||||
| 	TopP            float64  `json:"topP,omitempty"` | ||||
| 	TopK            float64  `json:"topK,omitempty"` | ||||
| 	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"` | ||||
| 	CandidateCount  int      `json:"candidateCount,omitempty"` | ||||
| 	StopSequences   []string `json:"stopSequences,omitempty"` | ||||
| 	ResponseMimeType string   `json:"responseMimeType,omitempty"` | ||||
| 	ResponseSchema   any      `json:"responseSchema,omitempty"` | ||||
| 	Temperature      *float64 `json:"temperature,omitempty"` | ||||
| 	TopP             *float64 `json:"topP,omitempty"` | ||||
| 	TopK             float64  `json:"topK,omitempty"` | ||||
| 	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"` | ||||
| 	CandidateCount   int      `json:"candidateCount,omitempty"` | ||||
| 	StopSequences    []string `json:"stopSequences,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -4,9 +4,24 @@ package groq | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"gemma-7b-it", | ||||
| 	"llama2-7b-2048", | ||||
| 	"llama2-70b-4096", | ||||
| 	"mixtral-8x7b-32768", | ||||
| 	"llama3-8b-8192", | ||||
| 	"gemma2-9b-it", | ||||
| 	"llama-3.1-70b-versatile", | ||||
| 	"llama-3.1-8b-instant", | ||||
| 	"llama-3.2-11b-text-preview", | ||||
| 	"llama-3.2-11b-vision-preview", | ||||
| 	"llama-3.2-1b-preview", | ||||
| 	"llama-3.2-3b-preview", | ||||
| 	"llama-3.2-11b-vision-preview", | ||||
| 	"llama-3.2-90b-text-preview", | ||||
| 	"llama-3.2-90b-vision-preview", | ||||
| 	"llama-guard-3-8b", | ||||
| 	"llama3-70b-8192", | ||||
| 	"llama3-8b-8192", | ||||
| 	"llama3-groq-70b-8192-tool-use-preview", | ||||
| 	"llama3-groq-8b-8192-tool-use-preview", | ||||
| 	"llava-v1.5-7b-4096-preview", | ||||
| 	"mixtral-8x7b-32768", | ||||
| 	"distil-whisper-large-v3-en", | ||||
| 	"whisper-large-v3", | ||||
| 	"whisper-large-v3-turbo", | ||||
| } | ||||
|   | ||||
							
								
								
									
										19
									
								
								relay/adaptor/novita/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								relay/adaptor/novita/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package novita | ||||
|  | ||||
| // https://novita.ai/llm-api | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"meta-llama/llama-3-8b-instruct", | ||||
| 	"meta-llama/llama-3-70b-instruct", | ||||
| 	"nousresearch/hermes-2-pro-llama-3-8b", | ||||
| 	"nousresearch/nous-hermes-llama2-13b", | ||||
| 	"mistralai/mistral-7b-instruct", | ||||
| 	"cognitivecomputations/dolphin-mixtral-8x22b", | ||||
| 	"sao10k/l3-70b-euryale-v2.1", | ||||
| 	"sophosympatheia/midnight-rose-70b", | ||||
| 	"gryphe/mythomax-l2-13b", | ||||
| 	"Nous-Hermes-2-Mixtral-8x7B-DPO", | ||||
| 	"lzlv_70b", | ||||
| 	"teknium/openhermes-2.5-mistral-7b", | ||||
| 	"microsoft/wizardlm-2-8x22b", | ||||
| } | ||||
							
								
								
									
										15
									
								
								relay/adaptor/novita/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								relay/adaptor/novita/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| package novita | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| func GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if meta.Mode == relaymode.ChatCompletions { | ||||
| 		return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil | ||||
| 	} | ||||
| 	return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode) | ||||
| } | ||||
| @@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	// https://github.com/ollama/ollama/blob/main/docs/api.md | ||||
| 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) | ||||
| 	if meta.Mode == relaymode.Embeddings { | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL) | ||||
| 	} | ||||
| 	return fullRequestURL, nil | ||||
| } | ||||
|   | ||||
| @@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			TopP:             request.TopP, | ||||
| 			FrequencyPenalty: request.FrequencyPenalty, | ||||
| 			PresencePenalty:  request.PresencePenalty, | ||||
| 			NumPredict:       request.MaxTokens, | ||||
| 			NumCtx:           request.NumCtx, | ||||
| 		}, | ||||
| 		Stream: request.Stream, | ||||
| 	} | ||||
| @@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 	common.SetEventStreamHeaders(c) | ||||
|  | ||||
| 	for scanner.Scan() { | ||||
| 		data := strings.TrimPrefix(scanner.Text(), "}") | ||||
| 		data = data + "}" | ||||
| 		data := scanner.Text() | ||||
| 		if strings.HasPrefix(data, "}") { | ||||
| 			data = strings.TrimPrefix(data, "}") + "}" | ||||
| 		} | ||||
|  | ||||
| 		var ollamaResponse ChatResponse | ||||
| 		err := json.Unmarshal([]byte(data), &ollamaResponse) | ||||
| @@ -157,8 +161,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
|  | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model:  request.Model, | ||||
| 		Prompt: strings.Join(request.ParseInput(), " "), | ||||
| 		Model: request.Model, | ||||
| 		Input: request.ParseInput(), | ||||
| 		Options: &Options{ | ||||
| 			Seed:             int(request.Seed), | ||||
| 			Temperature:      request.Temperature, | ||||
| 			TopP:             request.TopP, | ||||
| 			FrequencyPenalty: request.FrequencyPenalty, | ||||
| 			PresencePenalty:  request.PresencePenalty, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, 1), | ||||
| 		Model:  "text-embedding-v1", | ||||
| 		Model:  response.Model, | ||||
| 		Usage:  model.Usage{TotalTokens: 0}, | ||||
| 	} | ||||
|  | ||||
| 	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 		Object:    `embedding`, | ||||
| 		Index:     0, | ||||
| 		Embedding: response.Embedding, | ||||
| 	}) | ||||
| 	for i, embedding := range response.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    `embedding`, | ||||
| 			Index:     i, | ||||
| 			Embedding: embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,12 +1,14 @@ | ||||
| package ollama | ||||
|  | ||||
| type Options struct { | ||||
| 	Seed             int     `json:"seed,omitempty"` | ||||
| 	Temperature      float64 `json:"temperature,omitempty"` | ||||
| 	TopK             int     `json:"top_k,omitempty"` | ||||
| 	TopP             float64 `json:"top_p,omitempty"` | ||||
| 	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64 `json:"presence_penalty,omitempty"` | ||||
| 	Seed             int      `json:"seed,omitempty"` | ||||
| 	Temperature      *float64 `json:"temperature,omitempty"` | ||||
| 	TopK             int      `json:"top_k,omitempty"` | ||||
| 	TopP             *float64 `json:"top_p,omitempty"` | ||||
| 	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  *float64 `json:"presence_penalty,omitempty"` | ||||
| 	NumPredict       int      `json:"num_predict,omitempty"` | ||||
| 	NumCtx           int      `json:"num_ctx,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| @@ -37,11 +39,15 @@ type ChatResponse struct { | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	Model  string `json:"model"` | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	Model string   `json:"model"` | ||||
| 	Input []string `json:"input"` | ||||
| 	// Truncate  bool     `json:"truncate,omitempty"` | ||||
| 	Options *Options `json:"options,omitempty"` | ||||
| 	// KeepAlive string   `json:"keep_alive,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Error     string    `json:"error,omitempty"` | ||||
| 	Embedding []float64 `json:"embedding,omitempty"` | ||||
| 	Error      string      `json:"error,omitempty"` | ||||
| 	Model      string      `json:"model"` | ||||
| 	Embeddings [][]float64 `json:"embeddings"` | ||||
| } | ||||
|   | ||||
| @@ -3,17 +3,19 @@ package openai | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/doubao" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/novita" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| @@ -48,6 +50,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 		return minimax.GetRequestURL(meta) | ||||
| 	case channeltype.Doubao: | ||||
| 		return doubao.GetRequestURL(meta) | ||||
| 	case channeltype.Novita: | ||||
| 		return novita.GetRequestURL(meta) | ||||
| 	default: | ||||
| 		return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil | ||||
| 	} | ||||
| @@ -71,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	if request.Stream { | ||||
| 		// always return usage in stream mode | ||||
| 		if request.StreamOptions == nil { | ||||
| 			request.StreamOptions = &model.StreamOptions{} | ||||
| 		} | ||||
| 		request.StreamOptions.IncludeUsage = true | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -10,8 +10,11 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/novita" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/stepfun" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/togetherai" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/xai" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| ) | ||||
|  | ||||
| @@ -28,6 +31,9 @@ var CompatibleChannels = []int{ | ||||
| 	channeltype.StepFun, | ||||
| 	channeltype.DeepSeek, | ||||
| 	channeltype.TogetherAI, | ||||
| 	channeltype.Novita, | ||||
| 	channeltype.SiliconFlow, | ||||
| 	channeltype.XAI, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| @@ -56,6 +62,12 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 		return "together.ai", togetherai.ModelList | ||||
| 	case channeltype.Doubao: | ||||
| 		return "doubao", doubao.ModelList | ||||
| 	case channeltype.Novita: | ||||
| 		return "novita", novita.ModelList | ||||
| 	case channeltype.SiliconFlow: | ||||
| 		return "siliconflow", siliconflow.ModelList | ||||
| 	case channeltype.XAI: | ||||
| 		return "xai", xai.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
|   | ||||
| @@ -8,6 +8,10 @@ var ModelList = []string{ | ||||
| 	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", | ||||
| 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", | ||||
| 	"gpt-4o", "gpt-4o-2024-05-13", | ||||
| 	"gpt-4o-2024-08-06", | ||||
| 	"gpt-4o-2024-11-20", | ||||
| 	"chatgpt-4o-latest", | ||||
| 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18", | ||||
| 	"gpt-4-vision-preview", | ||||
| 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", | ||||
| 	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", | ||||
| @@ -17,4 +21,7 @@ var ModelList = []string{ | ||||
| 	"dall-e-2", "dall-e-3", | ||||
| 	"whisper-1", | ||||
| 	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", | ||||
| 	"o1", "o1-2024-12-17", | ||||
| 	"o1-preview", "o1-preview-2024-09-12", | ||||
| 	"o1-mini", "o1-mini-2024-09-12", | ||||
| } | ||||
|   | ||||
| @@ -2,15 +2,16 @@ package openai | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { | ||||
| func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { | ||||
| 	usage := &model.Usage{} | ||||
| 	usage.PromptTokens = promptTokens | ||||
| 	usage.CompletionTokens = CountTokenText(responseText, modeName) | ||||
| 	usage.CompletionTokens = CountTokenText(responseText, modelName) | ||||
| 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||
| 	return usage | ||||
| } | ||||
|   | ||||
| @@ -4,11 +4,12 @@ import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| @@ -31,6 +32,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
|  | ||||
| 	doneRendered := false | ||||
| 	for scanner.Scan() { | ||||
| 		data := scanner.Text() | ||||
| 		if len(data) < dataPrefixLength { // ignore blank line or wrong format | ||||
| @@ -41,6 +43,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 		} | ||||
| 		if strings.HasPrefix(data[dataPrefixLength:], done) { | ||||
| 			render.StringData(c, data) | ||||
| 			doneRendered = true | ||||
| 			continue | ||||
| 		} | ||||
| 		switch relayMode { | ||||
| @@ -52,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 				render.StringData(c, data) // if error happened, pass the data to client | ||||
| 				continue                   // just ignore the error | ||||
| 			} | ||||
| 			if len(streamResponse.Choices) == 0 { | ||||
| 				// but for empty choice, we should not pass it to client, this is for azure | ||||
| 			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil { | ||||
| 				// but for empty choice and no usage, we should not pass it to client, this is for azure | ||||
| 				continue // just ignore empty choice | ||||
| 			} | ||||
| 			render.StringData(c, data) | ||||
| @@ -81,7 +84,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 		logger.SysError("error reading stream: " + err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	render.Done(c) | ||||
| 	if !doneRendered { | ||||
| 		render.Done(c) | ||||
| 	} | ||||
|  | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -97,7 +97,11 @@ func CountTokenMessages(messages []model.Message, model string) int { | ||||
| 				m := it.(map[string]any) | ||||
| 				switch m["type"] { | ||||
| 				case "text": | ||||
| 					tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) | ||||
| 					if textValue, ok := m["text"]; ok { | ||||
| 						if textString, ok := textValue.(string); ok { | ||||
| 							tokenNum += getTokenNum(tokenEncoder, textString) | ||||
| 						} | ||||
| 					} | ||||
| 				case "image_url": | ||||
| 					imageUrl, ok := m["image_url"].(map[string]any) | ||||
| 					if ok { | ||||
| @@ -106,7 +110,7 @@ func CountTokenMessages(messages []model.Message, model string) int { | ||||
| 						if imageUrl["detail"] != nil { | ||||
| 							detail = imageUrl["detail"].(string) | ||||
| 						} | ||||
| 						imageTokens, err := countImageTokens(url, detail) | ||||
| 						imageTokens, err := countImageTokens(url, detail, model) | ||||
| 						if err != nil { | ||||
| 							logger.SysError("error counting image tokens: " + err.Error()) | ||||
| 						} else { | ||||
| @@ -130,11 +134,15 @@ const ( | ||||
| 	lowDetailCost         = 85 | ||||
| 	highDetailCostPerTile = 170 | ||||
| 	additionalCost        = 85 | ||||
| 	// gpt-4o-mini cost higher than other model | ||||
| 	gpt4oMiniLowDetailCost  = 2833 | ||||
| 	gpt4oMiniHighDetailCost = 5667 | ||||
| 	gpt4oMiniAdditionalCost = 2833 | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/guides/vision/calculating-costs | ||||
| // https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
| func countImageTokens(url string, detail string) (_ int, err error) { | ||||
| func countImageTokens(url string, detail string, model string) (_ int, err error) { | ||||
| 	var fetchSize = true | ||||
| 	var width, height int | ||||
| 	// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding | ||||
| @@ -168,6 +176,9 @@ func countImageTokens(url string, detail string) (_ int, err error) { | ||||
| 	} | ||||
| 	switch detail { | ||||
| 	case "low": | ||||
| 		if strings.HasPrefix(model, "gpt-4o-mini") { | ||||
| 			return gpt4oMiniLowDetailCost, nil | ||||
| 		} | ||||
| 		return lowDetailCost, nil | ||||
| 	case "high": | ||||
| 		if fetchSize { | ||||
| @@ -187,6 +198,9 @@ func countImageTokens(url string, detail string) (_ int, err error) { | ||||
| 			height = int(float64(height) * ratio) | ||||
| 		} | ||||
| 		numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) | ||||
| 		if strings.HasPrefix(model, "gpt-4o-mini") { | ||||
| 			return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil | ||||
| 		} | ||||
| 		result := numSquares*highDetailCostPerTile + additionalCost | ||||
| 		return result, nil | ||||
| 	default: | ||||
|   | ||||
| @@ -1,8 +1,16 @@ | ||||
| package openai | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/relay/model" | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { | ||||
| 	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) | ||||
|  | ||||
| 	Error := model.Error{ | ||||
| 		Message: err.Error(), | ||||
| 		Type:    "one_api_error", | ||||
|   | ||||
| @@ -19,11 +19,11 @@ type Prompt struct { | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Prompt         Prompt  `json:"prompt"` | ||||
| 	Temperature    float64 `json:"temperature,omitempty"` | ||||
| 	CandidateCount int     `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64 `json:"topP,omitempty"` | ||||
| 	TopK           int     `json:"topK,omitempty"` | ||||
| 	Prompt         Prompt   `json:"prompt"` | ||||
| 	Temperature    *float64 `json:"temperature,omitempty"` | ||||
| 	CandidateCount int      `json:"candidateCount,omitempty"` | ||||
| 	TopP           *float64 `json:"topP,omitempty"` | ||||
| 	TopK           int      `json:"topK,omitempty"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
|   | ||||
							
								
								
									
										89
									
								
								relay/adaptor/proxy/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								relay/adaptor/proxy/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | ||||
| package proxy | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	channelhelper "github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| var _ adaptor.Adaptor = new(Adaptor) | ||||
|  | ||||
| const channelName = "proxy" | ||||
|  | ||||
| type Adaptor struct{} | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	return nil, errors.New("notimplement") | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	for k, v := range resp.Header { | ||||
| 		for _, vv := range v { | ||||
| 			c.Writer.Header().Set(k, vv) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil { | ||||
| 		return nil, &relaymodel.ErrorWithStatusCode{ | ||||
| 			StatusCode: http.StatusInternalServerError, | ||||
| 			Error: relaymodel.Error{ | ||||
| 				Message: gerr.Error(), | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() (models []string) { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return channelName | ||||
| } | ||||
|  | ||||
| // GetRequestURL remove static prefix, and return the real request url to the upstream service | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId) | ||||
| 	return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil | ||||
|  | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	for k, v := range c.Request.Header { | ||||
| 		req.Header.Set(k, v[0]) | ||||
| 	} | ||||
|  | ||||
| 	// remove unnecessary headers | ||||
| 	req.Header.Del("Host") | ||||
| 	req.Header.Del("Content-Length") | ||||
| 	req.Header.Del("Accept-Encoding") | ||||
| 	req.Header.Del("Connection") | ||||
|  | ||||
| 	// set authorization header | ||||
| 	req.Header.Set("Authorization", meta.APIKey) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	return nil, errors.Errorf("not implement") | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return channelhelper.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
							
								
								
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,136 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"slices" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta *meta.Meta | ||||
| } | ||||
|  | ||||
| // ConvertImageRequest implements adaptor.Adaptor. | ||||
| func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	return DrawImageRequest{ | ||||
| 		Input: ImageInput{ | ||||
| 			Steps:           25, | ||||
| 			Prompt:          request.Prompt, | ||||
| 			Guidance:        3, | ||||
| 			Seed:            int(time.Now().UnixNano()), | ||||
| 			SafetyTolerance: 5, | ||||
| 			NImages:         1, // replicate will always return 1 image | ||||
| 			Width:           1440, | ||||
| 			Height:          1440, | ||||
| 			AspectRatio:     "1:1", | ||||
| 		}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if !request.Stream { | ||||
| 		// TODO: support non-stream mode | ||||
| 		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") | ||||
| 	} | ||||
|  | ||||
| 	// Build the prompt from OpenAI messages | ||||
| 	var promptBuilder strings.Builder | ||||
| 	for _, message := range request.Messages { | ||||
| 		switch msgCnt := message.Content.(type) { | ||||
| 		case string: | ||||
| 			promptBuilder.WriteString(message.Role) | ||||
| 			promptBuilder.WriteString(": ") | ||||
| 			promptBuilder.WriteString(msgCnt) | ||||
| 			promptBuilder.WriteString("\n") | ||||
| 		default: | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	replicateRequest := ReplicateChatRequest{ | ||||
| 		Input: ChatInput{ | ||||
| 			Prompt:           promptBuilder.String(), | ||||
| 			MaxTokens:        request.MaxTokens, | ||||
| 			Temperature:      1.0, | ||||
| 			TopP:             1.0, | ||||
| 			PresencePenalty:  0.0, | ||||
| 			FrequencyPenalty: 0.0, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// Map optional fields | ||||
| 	if request.Temperature != nil { | ||||
| 		replicateRequest.Input.Temperature = *request.Temperature | ||||
| 	} | ||||
| 	if request.TopP != nil { | ||||
| 		replicateRequest.Input.TopP = *request.TopP | ||||
| 	} | ||||
| 	if request.PresencePenalty != nil { | ||||
| 		replicateRequest.Input.PresencePenalty = *request.PresencePenalty | ||||
| 	} | ||||
| 	if request.FrequencyPenalty != nil { | ||||
| 		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty | ||||
| 	} | ||||
| 	if request.MaxTokens > 0 { | ||||
| 		replicateRequest.Input.MaxTokens = request.MaxTokens | ||||
| 	} else if request.MaxTokens == 0 { | ||||
| 		replicateRequest.Input.MaxTokens = 500 | ||||
| 	} | ||||
|  | ||||
| 	return replicateRequest, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	if !slices.Contains(ModelList, meta.OriginModelName) { | ||||
| 		return "", errors.Errorf("model %s not supported", meta.OriginModelName) | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	logger.Info(c, "send request to replicate") | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	switch meta.Mode { | ||||
| 	case relaymode.ImagesGenerations: | ||||
| 		err, usage = ImageHandler(c, resp) | ||||
| 	case relaymode.ChatCompletions: | ||||
| 		err, usage = ChatHandler(c, resp) | ||||
| 	default: | ||||
| 		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "replicate" | ||||
| } | ||||
							
								
								
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,191 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ChatHandler(c *gin.Context, resp *http.Response) ( | ||||
| 	srvErr *model.ErrorWithStatusCode, usage *model.Usage) { | ||||
| 	if resp.StatusCode != http.StatusCreated { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return openai.ErrorWrapper( | ||||
| 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||
| 				"bad_status_code", http.StatusInternalServerError), | ||||
| 			nil | ||||
| 	} | ||||
|  | ||||
| 	respBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	respData := new(ChatResponse) | ||||
| 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		err = func() error { | ||||
| 			// get task | ||||
| 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 				http.MethodGet, respData.URLs.Get, nil) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "new request") | ||||
| 			} | ||||
|  | ||||
| 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get task") | ||||
| 			} | ||||
| 			defer taskResp.Body.Close() | ||||
|  | ||||
| 			if taskResp.StatusCode != http.StatusOK { | ||||
| 				payload, _ := io.ReadAll(taskResp.Body) | ||||
| 				return errors.Errorf("bad status code [%d]%s", | ||||
| 					taskResp.StatusCode, string(payload)) | ||||
| 			} | ||||
|  | ||||
| 			taskBody, err := io.ReadAll(taskResp.Body) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "read task response") | ||||
| 			} | ||||
|  | ||||
| 			taskData := new(ChatResponse) | ||||
| 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||
| 				return errors.Wrap(err, "decode task response") | ||||
| 			} | ||||
|  | ||||
| 			switch taskData.Status { | ||||
| 			case "succeeded": | ||||
| 			case "failed", "canceled": | ||||
| 				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) | ||||
| 			default: | ||||
| 				time.Sleep(time.Second * 3) | ||||
| 				return errNextLoop | ||||
| 			} | ||||
|  | ||||
| 			if taskData.URLs.Stream == "" { | ||||
| 				return errors.New("stream url is empty") | ||||
| 			} | ||||
|  | ||||
| 			// request stream url | ||||
| 			responseText, err := chatStreamHandler(c, taskData.URLs.Stream) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "chat stream handler") | ||||
| 			} | ||||
|  | ||||
| 			ctxMeta := meta.GetByContext(c) | ||||
| 			usage = openai.ResponseText2Usage(responseText, | ||||
| 				ctxMeta.ActualModelName, ctxMeta.PromptTokens) | ||||
| 			return nil | ||||
| 		}() | ||||
| 		if err != nil { | ||||
| 			if errors.Is(err, errNextLoop) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
|  | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	eventPrefix = "event: " | ||||
| 	dataPrefix  = "data: " | ||||
| 	done        = "[DONE]" | ||||
| ) | ||||
|  | ||||
| func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { | ||||
| 	// request stream endpoint | ||||
| 	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) | ||||
| 	if err != nil { | ||||
| 		return "", errors.Wrap(err, "new request to stream") | ||||
| 	} | ||||
|  | ||||
| 	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 	streamReq.Header.Set("Accept", "text/event-stream") | ||||
| 	streamReq.Header.Set("Cache-Control", "no-store") | ||||
|  | ||||
| 	resp, err := http.DefaultClient.Do(streamReq) | ||||
| 	if err != nil { | ||||
| 		return "", errors.Wrap(err, "do request to stream") | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) | ||||
| 	} | ||||
|  | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(bufio.ScanLines) | ||||
|  | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	doneRendered := false | ||||
| 	for scanner.Scan() { | ||||
| 		line := strings.TrimSpace(scanner.Text()) | ||||
| 		if line == "" { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// Handle comments starting with ':' | ||||
| 		if strings.HasPrefix(line, ":") { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// Parse SSE fields | ||||
| 		if strings.HasPrefix(line, eventPrefix) { | ||||
| 			event := strings.TrimSpace(line[len(eventPrefix):]) | ||||
| 			var data string | ||||
| 			// Read the following lines to get data and id | ||||
| 			for scanner.Scan() { | ||||
| 				nextLine := scanner.Text() | ||||
| 				if nextLine == "" { | ||||
| 					break | ||||
| 				} | ||||
| 				if strings.HasPrefix(nextLine, dataPrefix) { | ||||
| 					data = nextLine[len(dataPrefix):] | ||||
| 				} else if strings.HasPrefix(nextLine, "id:") { | ||||
| 					// id = strings.TrimSpace(nextLine[len("id:"):]) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			if event == "output" { | ||||
| 				render.StringData(c, data) | ||||
| 				responseText += data | ||||
| 			} else if event == "done" { | ||||
| 				render.Done(c) | ||||
| 				doneRendered = true | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := scanner.Err(); err != nil { | ||||
| 		return "", errors.Wrap(err, "scan stream") | ||||
| 	} | ||||
|  | ||||
| 	if !doneRendered { | ||||
| 		render.Done(c) | ||||
| 	} | ||||
|  | ||||
| 	return responseText, nil | ||||
| } | ||||
							
								
								
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| package replicate | ||||
|  | ||||
| // ModelList is a list of models that can be used with Replicate. | ||||
| // | ||||
| // https://replicate.com/pricing | ||||
| var ModelList = []string{ | ||||
| 	// ------------------------------------- | ||||
| 	// image model | ||||
| 	// ------------------------------------- | ||||
| 	"black-forest-labs/flux-1.1-pro", | ||||
| 	"black-forest-labs/flux-1.1-pro-ultra", | ||||
| 	"black-forest-labs/flux-canny-dev", | ||||
| 	"black-forest-labs/flux-canny-pro", | ||||
| 	"black-forest-labs/flux-depth-dev", | ||||
| 	"black-forest-labs/flux-depth-pro", | ||||
| 	"black-forest-labs/flux-dev", | ||||
| 	"black-forest-labs/flux-dev-lora", | ||||
| 	"black-forest-labs/flux-fill-dev", | ||||
| 	"black-forest-labs/flux-fill-pro", | ||||
| 	"black-forest-labs/flux-pro", | ||||
| 	"black-forest-labs/flux-redux-dev", | ||||
| 	"black-forest-labs/flux-redux-schnell", | ||||
| 	"black-forest-labs/flux-schnell", | ||||
| 	"black-forest-labs/flux-schnell-lora", | ||||
| 	"ideogram-ai/ideogram-v2", | ||||
| 	"ideogram-ai/ideogram-v2-turbo", | ||||
| 	"recraft-ai/recraft-v3", | ||||
| 	"recraft-ai/recraft-v3-svg", | ||||
| 	"stability-ai/stable-diffusion-3", | ||||
| 	"stability-ai/stable-diffusion-3.5-large", | ||||
| 	"stability-ai/stable-diffusion-3.5-large-turbo", | ||||
| 	"stability-ai/stable-diffusion-3.5-medium", | ||||
| 	// ------------------------------------- | ||||
| 	// language model | ||||
| 	// ------------------------------------- | ||||
| 	"ibm-granite/granite-20b-code-instruct-8k", | ||||
| 	"ibm-granite/granite-3.0-2b-instruct", | ||||
| 	"ibm-granite/granite-3.0-8b-instruct", | ||||
| 	"ibm-granite/granite-8b-code-instruct-128k", | ||||
| 	"meta/llama-2-13b", | ||||
| 	"meta/llama-2-13b-chat", | ||||
| 	"meta/llama-2-70b", | ||||
| 	"meta/llama-2-70b-chat", | ||||
| 	"meta/llama-2-7b", | ||||
| 	"meta/llama-2-7b-chat", | ||||
| 	"meta/meta-llama-3.1-405b-instruct", | ||||
| 	"meta/meta-llama-3-70b", | ||||
| 	"meta/meta-llama-3-70b-instruct", | ||||
| 	"meta/meta-llama-3-8b", | ||||
| 	"meta/meta-llama-3-8b-instruct", | ||||
| 	"mistralai/mistral-7b-instruct-v0.2", | ||||
| 	"mistralai/mistral-7b-v0.1", | ||||
| 	"mistralai/mixtral-8x7b-instruct-v0.1", | ||||
| 	// ------------------------------------- | ||||
| 	// video model | ||||
| 	// ------------------------------------- | ||||
| 	// "minimax/video-01",  // TODO: implement the adaptor | ||||
| } | ||||
							
								
								
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,222 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"image" | ||||
| 	"image/png" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"golang.org/x/image/webp" | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| ) | ||||
|  | ||||
| // ImagesEditsHandler just copy response body to client | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-fill-pro | ||||
| // func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| // 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| // 	for k, v := range resp.Header { | ||||
| // 		c.Writer.Header().Set(k, v[0]) | ||||
| // 	} | ||||
|  | ||||
| // 	if _, err := io.Copy(c.Writer, resp.Body); err != nil { | ||||
| // 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| // 	} | ||||
| // 	defer resp.Body.Close() | ||||
|  | ||||
| // 	return nil, nil | ||||
| // } | ||||
|  | ||||
| var errNextLoop = errors.New("next_loop") | ||||
|  | ||||
| func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	if resp.StatusCode != http.StatusCreated { | ||||
| 		payload, _ := io.ReadAll(resp.Body) | ||||
| 		return openai.ErrorWrapper( | ||||
| 				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), | ||||
| 				"bad_status_code", http.StatusInternalServerError), | ||||
| 			nil | ||||
| 	} | ||||
|  | ||||
| 	respBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	respData := new(ImageResponse) | ||||
| 	if err = json.Unmarshal(respBody, respData); err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		err = func() error { | ||||
| 			// get task | ||||
| 			taskReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 				http.MethodGet, respData.URLs.Get, nil) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "new request") | ||||
| 			} | ||||
|  | ||||
| 			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) | ||||
| 			taskResp, err := http.DefaultClient.Do(taskReq) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get task") | ||||
| 			} | ||||
| 			defer taskResp.Body.Close() | ||||
|  | ||||
| 			if taskResp.StatusCode != http.StatusOK { | ||||
| 				payload, _ := io.ReadAll(taskResp.Body) | ||||
| 				return errors.Errorf("bad status code [%d]%s", | ||||
| 					taskResp.StatusCode, string(payload)) | ||||
| 			} | ||||
|  | ||||
| 			taskBody, err := io.ReadAll(taskResp.Body) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "read task response") | ||||
| 			} | ||||
|  | ||||
| 			taskData := new(ImageResponse) | ||||
| 			if err = json.Unmarshal(taskBody, taskData); err != nil { | ||||
| 				return errors.Wrap(err, "decode task response") | ||||
| 			} | ||||
|  | ||||
| 			switch taskData.Status { | ||||
| 			case "succeeded": | ||||
| 			case "failed", "canceled": | ||||
| 				return errors.Errorf("task failed: %s", taskData.Status) | ||||
| 			default: | ||||
| 				time.Sleep(time.Second * 3) | ||||
| 				return errNextLoop | ||||
| 			} | ||||
|  | ||||
| 			output, err := taskData.GetOutput() | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "get output") | ||||
| 			} | ||||
| 			if len(output) == 0 { | ||||
| 				return errors.New("response output is empty") | ||||
| 			} | ||||
|  | ||||
| 			var mu sync.Mutex | ||||
| 			var pool errgroup.Group | ||||
| 			respBody := &openai.ImageResponse{ | ||||
| 				Created: taskData.CompletedAt.Unix(), | ||||
| 				Data:    []openai.ImageData{}, | ||||
| 			} | ||||
|  | ||||
| 			for _, imgOut := range output { | ||||
| 				imgOut := imgOut | ||||
| 				pool.Go(func() error { | ||||
| 					// download image | ||||
| 					downloadReq, err := http.NewRequestWithContext(c.Request.Context(), | ||||
| 						http.MethodGet, imgOut, nil) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "new request") | ||||
| 					} | ||||
|  | ||||
| 					imgResp, err := http.DefaultClient.Do(downloadReq) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "download image") | ||||
| 					} | ||||
| 					defer imgResp.Body.Close() | ||||
|  | ||||
| 					if imgResp.StatusCode != http.StatusOK { | ||||
| 						payload, _ := io.ReadAll(imgResp.Body) | ||||
| 						return errors.Errorf("bad status code [%d]%s", | ||||
| 							imgResp.StatusCode, string(payload)) | ||||
| 					} | ||||
|  | ||||
| 					imgData, err := io.ReadAll(imgResp.Body) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "read image") | ||||
| 					} | ||||
|  | ||||
| 					imgData, err = ConvertImageToPNG(imgData) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrap(err, "convert image") | ||||
| 					} | ||||
|  | ||||
| 					mu.Lock() | ||||
| 					respBody.Data = append(respBody.Data, openai.ImageData{ | ||||
| 						B64Json: fmt.Sprintf("data:image/png;base64,%s", | ||||
| 							base64.StdEncoding.EncodeToString(imgData)), | ||||
| 					}) | ||||
| 					mu.Unlock() | ||||
|  | ||||
| 					return nil | ||||
| 				}) | ||||
| 			} | ||||
|  | ||||
| 			if err := pool.Wait(); err != nil { | ||||
| 				if len(respBody.Data) == 0 { | ||||
| 					return errors.WithStack(err) | ||||
| 				} | ||||
|  | ||||
| 				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) | ||||
| 			} | ||||
|  | ||||
| 			c.JSON(http.StatusOK, respBody) | ||||
| 			return nil | ||||
| 		}() | ||||
| 		if err != nil { | ||||
| 			if errors.Is(err, errNextLoop) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
|  | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| // ConvertImageToPNG converts a WebP image to PNG format | ||||
| func ConvertImageToPNG(webpData []byte) ([]byte, error) { | ||||
| 	// bypass if it's already a PNG image | ||||
| 	if bytes.HasPrefix(webpData, []byte("\x89PNG")) { | ||||
| 		return webpData, nil | ||||
| 	} | ||||
|  | ||||
| 	// check if is jpeg, convert to png | ||||
| 	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) { | ||||
| 		img, _, err := image.Decode(bytes.NewReader(webpData)) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "decode jpeg") | ||||
| 		} | ||||
|  | ||||
| 		var pngBuffer bytes.Buffer | ||||
| 		if err := png.Encode(&pngBuffer, img); err != nil { | ||||
| 			return nil, errors.Wrap(err, "encode png") | ||||
| 		} | ||||
|  | ||||
| 		return pngBuffer.Bytes(), nil | ||||
| 	} | ||||
|  | ||||
| 	// Decode the WebP image | ||||
| 	img, err := webp.Decode(bytes.NewReader(webpData)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "decode webp") | ||||
| 	} | ||||
|  | ||||
| 	// Encode the image as PNG | ||||
| 	var pngBuffer bytes.Buffer | ||||
| 	if err := png.Encode(&pngBuffer, img); err != nil { | ||||
| 		return nil, errors.Wrap(err, "encode png") | ||||
| 	} | ||||
|  | ||||
| 	return pngBuffer.Bytes(), nil | ||||
| } | ||||
							
								
								
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | ||||
| package replicate | ||||
|  | ||||
| import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
|  | ||||
| // DrawImageRequest draw image by fluxpro | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json | ||||
| type DrawImageRequest struct { | ||||
| 	Input ImageInput `json:"input"` | ||||
| } | ||||
|  | ||||
| // ImageInput is input of DrawImageByFluxProRequest | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema | ||||
| type ImageInput struct { | ||||
| 	Steps           int    `json:"steps" binding:"required,min=1"` | ||||
| 	Prompt          string `json:"prompt" binding:"required,min=5"` | ||||
| 	ImagePrompt     string `json:"image_prompt"` | ||||
| 	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"` | ||||
| 	Interval        int    `json:"interval" binding:"required,min=1,max=4"` | ||||
| 	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"` | ||||
| 	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"` | ||||
| 	Seed            int    `json:"seed"` | ||||
| 	NImages         int    `json:"n_images" binding:"required,min=1,max=8"` | ||||
| 	Width           int    `json:"width" binding:"required,min=256,max=1440"` | ||||
| 	Height          int    `json:"height" binding:"required,min=256,max=1440"` | ||||
| } | ||||
|  | ||||
| // InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-fill-pro/api/schema | ||||
| type InpaintingImageByFlusReplicateRequest struct { | ||||
| 	Input FluxInpaintingInput `json:"input"` | ||||
| } | ||||
|  | ||||
| // FluxInpaintingInput is input of DrawImageByFluxProRequest | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-fill-pro/api/schema | ||||
| type FluxInpaintingInput struct { | ||||
| 	Mask             string `json:"mask" binding:"required"` | ||||
| 	Image            string `json:"image" binding:"required"` | ||||
| 	Seed             int    `json:"seed"` | ||||
| 	Steps            int    `json:"steps" binding:"required,min=1"` | ||||
| 	Prompt           string `json:"prompt" binding:"required,min=5"` | ||||
| 	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"` | ||||
| 	OutputFormat     string `json:"output_format"` | ||||
| 	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"` | ||||
| 	PromptUnsampling bool   `json:"prompt_unsampling"` | ||||
| } | ||||
|  | ||||
| // ImageResponse is response of DrawImageByFluxProRequest | ||||
| // | ||||
| // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json | ||||
| type ImageResponse struct { | ||||
| 	CompletedAt time.Time        `json:"completed_at"` | ||||
| 	CreatedAt   time.Time        `json:"created_at"` | ||||
| 	DataRemoved bool             `json:"data_removed"` | ||||
| 	Error       string           `json:"error"` | ||||
| 	ID          string           `json:"id"` | ||||
| 	Input       DrawImageRequest `json:"input"` | ||||
| 	Logs        string           `json:"logs"` | ||||
| 	Metrics     FluxMetrics      `json:"metrics"` | ||||
| 	// Output could be `string` or `[]string` | ||||
| 	Output    any       `json:"output"` | ||||
| 	StartedAt time.Time `json:"started_at"` | ||||
| 	Status    string    `json:"status"` | ||||
| 	URLs      FluxURLs  `json:"urls"` | ||||
| 	Version   string    `json:"version"` | ||||
| } | ||||
|  | ||||
| func (r *ImageResponse) GetOutput() ([]string, error) { | ||||
| 	switch v := r.Output.(type) { | ||||
| 	case string: | ||||
| 		return []string{v}, nil | ||||
| 	case []string: | ||||
| 		return v, nil | ||||
| 	case nil: | ||||
| 		return nil, nil | ||||
| 	case []interface{}: | ||||
| 		// convert []interface{} to []string | ||||
| 		ret := make([]string, len(v)) | ||||
| 		for idx, vv := range v { | ||||
| 			if vvv, ok := vv.(string); ok { | ||||
| 				ret[idx] = vvv | ||||
| 			} else { | ||||
| 				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return ret, nil | ||||
| 	default: | ||||
| 		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // FluxMetrics is metrics of ImageResponse | ||||
| type FluxMetrics struct { | ||||
| 	ImageCount  int     `json:"image_count"` | ||||
| 	PredictTime float64 `json:"predict_time"` | ||||
| 	TotalTime   float64 `json:"total_time"` | ||||
| } | ||||
|  | ||||
| // FluxURLs is urls of ImageResponse | ||||
| type FluxURLs struct { | ||||
| 	Get    string `json:"get"` | ||||
| 	Cancel string `json:"cancel"` | ||||
| } | ||||
|  | ||||
| type ReplicateChatRequest struct { | ||||
| 	Input ChatInput `json:"input" form:"input" binding:"required"` | ||||
| } | ||||
|  | ||||
| // ChatInput is input of ChatByReplicateRequest | ||||
| // | ||||
| // https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema | ||||
| type ChatInput struct { | ||||
| 	TopK             int     `json:"top_k"` | ||||
| 	TopP             float64 `json:"top_p"` | ||||
| 	Prompt           string  `json:"prompt"` | ||||
| 	MaxTokens        int     `json:"max_tokens"` | ||||
| 	MinTokens        int     `json:"min_tokens"` | ||||
| 	Temperature      float64 `json:"temperature"` | ||||
| 	SystemPrompt     string  `json:"system_prompt"` | ||||
| 	StopSequences    string  `json:"stop_sequences"` | ||||
| 	PromptTemplate   string  `json:"prompt_template"` | ||||
| 	PresencePenalty  float64 `json:"presence_penalty"` | ||||
| 	FrequencyPenalty float64 `json:"frequency_penalty"` | ||||
| } | ||||
|  | ||||
| // ChatResponse is response of ChatByReplicateRequest | ||||
| // | ||||
| // https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json | ||||
| type ChatResponse struct { | ||||
| 	CompletedAt time.Time   `json:"completed_at"` | ||||
| 	CreatedAt   time.Time   `json:"created_at"` | ||||
| 	DataRemoved bool        `json:"data_removed"` | ||||
| 	Error       string      `json:"error"` | ||||
| 	ID          string      `json:"id"` | ||||
| 	Input       ChatInput   `json:"input"` | ||||
| 	Logs        string      `json:"logs"` | ||||
| 	Metrics     FluxMetrics `json:"metrics"` | ||||
| 	// Output could be `string` or `[]string` | ||||
| 	Output    []string        `json:"output"` | ||||
| 	StartedAt time.Time       `json:"started_at"` | ||||
| 	Status    string          `json:"status"` | ||||
| 	URLs      ChatResponseUrl `json:"urls"` | ||||
| 	Version   string          `json:"version"` | ||||
| } | ||||
|  | ||||
| // ChatResponseUrl is task urls of ChatResponse | ||||
| type ChatResponseUrl struct { | ||||
| 	Stream string `json:"stream"` | ||||
| 	Get    string `json:"get"` | ||||
| 	Cancel string `json:"cancel"` | ||||
| } | ||||
							
								
								
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package siliconflow | ||||
|  | ||||
| // https://docs.siliconflow.cn/docs/getting-started | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"deepseek-ai/deepseek-llm-67b-chat", | ||||
| 	"Qwen/Qwen1.5-14B-Chat", | ||||
| 	"Qwen/Qwen1.5-7B-Chat", | ||||
| 	"Qwen/Qwen1.5-110B-Chat", | ||||
| 	"Qwen/Qwen1.5-32B-Chat", | ||||
| 	"01-ai/Yi-1.5-6B-Chat", | ||||
| 	"01-ai/Yi-1.5-9B-Chat-16K", | ||||
| 	"01-ai/Yi-1.5-34B-Chat-16K", | ||||
| 	"THUDM/chatglm3-6b", | ||||
| 	"deepseek-ai/DeepSeek-V2-Chat", | ||||
| 	"THUDM/glm-4-9b-chat", | ||||
| 	"Qwen/Qwen2-72B-Instruct", | ||||
| 	"Qwen/Qwen2-7B-Instruct", | ||||
| 	"Qwen/Qwen2-57B-A14B-Instruct", | ||||
| 	"deepseek-ai/DeepSeek-Coder-V2-Instruct", | ||||
| 	"Qwen/Qwen2-1.5B-Instruct", | ||||
| 	"internlm/internlm2_5-7b-chat", | ||||
| 	"BAAI/bge-large-en-v1.5", | ||||
| 	"BAAI/bge-large-zh-v1.5", | ||||
| 	"Pro/Qwen/Qwen2-7B-Instruct", | ||||
| 	"Pro/Qwen/Qwen2-1.5B-Instruct", | ||||
| 	"Pro/Qwen/Qwen1.5-7B-Chat", | ||||
| 	"Pro/THUDM/glm-4-9b-chat", | ||||
| 	"Pro/THUDM/chatglm3-6b", | ||||
| 	"Pro/01-ai/Yi-1.5-9B-Chat-16K", | ||||
| 	"Pro/01-ai/Yi-1.5-6B-Chat", | ||||
| 	"Pro/google/gemma-2-9b-it", | ||||
| 	"Pro/internlm/internlm2_5-7b-chat", | ||||
| 	"Pro/meta-llama/Meta-Llama-3-8B-Instruct", | ||||
| 	"Pro/mistralai/Mistral-7B-Instruct-v0.2", | ||||
| } | ||||
| @@ -1,7 +1,13 @@ | ||||
| package stepfun | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"step-1-8k", | ||||
| 	"step-1-32k", | ||||
| 	"step-1-128k", | ||||
| 	"step-1-256k", | ||||
| 	"step-1-flash", | ||||
| 	"step-2-16k", | ||||
| 	"step-1v-8k", | ||||
| 	"step-1v-32k", | ||||
| 	"step-1-200k", | ||||
| 	"step-1x-medium", | ||||
| } | ||||
|   | ||||
| @@ -2,16 +2,19 @@ package tencent | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| ) | ||||
|  | ||||
| // https://cloud.tencent.com/document/api/1729/101837 | ||||
| @@ -52,10 +55,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	tencentRequest := ConvertRequest(*request) | ||||
| 	var convertedRequest any | ||||
| 	switch relayMode { | ||||
| 	case relaymode.Embeddings: | ||||
| 		a.Action = "GetEmbedding" | ||||
| 		convertedRequest = ConvertEmbeddingRequest(*request) | ||||
| 	default: | ||||
| 		a.Action = "ChatCompletions" | ||||
| 		convertedRequest = ConvertRequest(*request) | ||||
| 	} | ||||
| 	// we have to calculate the sign here | ||||
| 	a.Sign = GetSign(*tencentRequest, a, secretId, secretKey) | ||||
| 	return tencentRequest, nil | ||||
| 	a.Sign = GetSign(convertedRequest, a, secretId, secretKey) | ||||
| 	return convertedRequest, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| @@ -75,7 +86,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met | ||||
| 		err, responseText = StreamHandler(c, resp) | ||||
| 		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp) | ||||
| 		switch meta.Mode { | ||||
| 		case relaymode.Embeddings: | ||||
| 			err, usage = EmbeddingHandler(c, resp) | ||||
| 		default: | ||||
| 			err, usage = Handler(c, resp) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -5,4 +5,6 @@ var ModelList = []string{ | ||||
| 	"hunyuan-standard", | ||||
| 	"hunyuan-standard-256K", | ||||
| 	"hunyuan-pro", | ||||
| 	"hunyuan-vision", | ||||
| 	"hunyuan-embedding", | ||||
| } | ||||
|   | ||||
| @@ -8,7 +8,6 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -16,11 +15,14 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/common/render" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| @@ -39,13 +41,73 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 		Model:       &request.Model, | ||||
| 		Stream:      &request.Stream, | ||||
| 		Messages:    messages, | ||||
| 		TopP:        &request.TopP, | ||||
| 		Temperature: &request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Temperature: request.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		InputList: request.ParseInput(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var tencentResponseP EmbeddingResponseP | ||||
| 	err := json.NewDecoder(resp.Body).Decode(&tencentResponseP) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	tencentResponse := tencentResponseP.Response | ||||
| 	if tencentResponse.Error.Code != "" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: tencentResponse.Error.Message, | ||||
| 				Code:    tencentResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	requestModel := c.GetString(ctxkey.RequestModel) | ||||
| 	fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse) | ||||
| 	fullTextResponse.Model = requestModel | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Model:  "hunyuan-embedding", | ||||
| 		Usage:  model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens}, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range response.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      response.ReqID, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Usage: model.Usage{ | ||||
| @@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	TencentResponse = responseP.Response | ||||
| 	if TencentResponse.Error.Code != 0 { | ||||
| 	if TencentResponse.Error.Code != "" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: TencentResponse.Error.Message, | ||||
| @@ -195,7 +257,7 @@ func hmacSha256(s, key string) string { | ||||
| 	return string(hashed.Sum(nil)) | ||||
| } | ||||
|  | ||||
| func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { | ||||
| func GetSign(req any, adaptor *Adaptor, secId, secKey string) string { | ||||
| 	// build canonical request string | ||||
| 	host := "hunyuan.tencentcloudapi.com" | ||||
| 	httpRequestMethod := "POST" | ||||
|   | ||||
| @@ -35,16 +35,16 @@ type ChatRequest struct { | ||||
| 	// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 | ||||
| 	// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 | ||||
| 	// 3. 非必要不建议使用,不合理的取值会影响效果。 | ||||
| 	TopP *float64 `json:"TopP"` | ||||
| 	TopP *float64 `json:"TopP,omitempty"` | ||||
| 	// 说明: | ||||
| 	// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 | ||||
| 	// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 | ||||
| 	// 3. 非必要不建议使用,不合理的取值会影响效果。 | ||||
| 	Temperature *float64 `json:"Temperature"` | ||||
| 	Temperature *float64 `json:"Temperature,omitempty"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| 	Code    int    `json:"Code"` | ||||
| 	Code    string `json:"Code"` | ||||
| 	Message string `json:"Message"` | ||||
| } | ||||
|  | ||||
| @@ -61,15 +61,41 @@ type ResponseChoices struct { | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 | ||||
| 	Created int64             `json:"Created,omitempty"` // unix 时间戳的字符串 | ||||
| 	Id      string            `json:"Id,omitempty"`      // 会话 id | ||||
| 	Usage   Usage             `json:"Usage,omitempty"`   // token 数量 | ||||
| 	Error   Error             `json:"Error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string            `json:"Note,omitempty"`    // 注释 | ||||
| 	ReqID   string            `json:"Req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| 	Choices []ResponseChoices `json:"Choices,omitempty"`   // 结果 | ||||
| 	Created int64             `json:"Created,omitempty"`   // unix 时间戳的字符串 | ||||
| 	Id      string            `json:"Id,omitempty"`        // 会话 id | ||||
| 	Usage   Usage             `json:"Usage,omitempty"`     // token 数量 | ||||
| 	Error   Error             `json:"Error,omitempty"`     // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string            `json:"Note,omitempty"`      // 注释 | ||||
| 	ReqID   string            `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| } | ||||
|  | ||||
| type ChatResponseP struct { | ||||
| 	Response ChatResponse `json:"Response,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	InputList []string `json:"InputList"` | ||||
| } | ||||
|  | ||||
| type EmbeddingData struct { | ||||
| 	Embedding []float64 `json:"Embedding"` | ||||
| 	Index     int       `json:"Index"` | ||||
| 	Object    string    `json:"Object"` | ||||
| } | ||||
|  | ||||
| type EmbeddingUsage struct { | ||||
| 	PromptTokens int `json:"PromptTokens"` | ||||
| 	TotalTokens  int `json:"TotalTokens"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponse struct { | ||||
| 	Data           []EmbeddingData `json:"Data"` | ||||
| 	EmbeddingUsage EmbeddingUsage  `json:"Usage,omitempty"` | ||||
| 	RequestId      string          `json:"RequestId,omitempty"` | ||||
| 	Error          Error           `json:"Error,omitempty"` | ||||
| } | ||||
|  | ||||
| type EmbeddingResponseP struct { | ||||
| 	Response EmbeddingResponse `json:"Response,omitempty"` | ||||
| } | ||||
|   | ||||
							
								
								
									
										117
									
								
								relay/adaptor/vertexai/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								relay/adaptor/vertexai/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,117 @@ | ||||
| package vertexai | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	channelhelper "github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| var _ adaptor.Adaptor = new(Adaptor) | ||||
|  | ||||
| const channelName = "vertexai" | ||||
|  | ||||
| type Adaptor struct{} | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
|  | ||||
| 	adaptor := GetAdaptor(request.Model) | ||||
| 	if adaptor == nil { | ||||
| 		return nil, errors.New("adaptor not found") | ||||
| 	} | ||||
|  | ||||
| 	return adaptor.ConvertRequest(c, relayMode, request) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	adaptor := GetAdaptor(meta.ActualModelName) | ||||
| 	if adaptor == nil { | ||||
| 		return nil, &relaymodel.ErrorWithStatusCode{ | ||||
| 			StatusCode: http.StatusInternalServerError, | ||||
| 			Error: relaymodel.Error{ | ||||
| 				Message: "adaptor not found", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	return adaptor.DoResponse(c, resp, meta) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() (models []string) { | ||||
| 	models = modelList | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return channelName | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	suffix := "" | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "gemini") { | ||||
| 		if meta.IsStream { | ||||
| 			suffix = "streamGenerateContent?alt=sse" | ||||
| 		} else { | ||||
| 			suffix = "generateContent" | ||||
| 		} | ||||
| 	} else { | ||||
| 		if meta.IsStream { | ||||
| 			suffix = "streamRawPredict?alt=sse" | ||||
| 		} else { | ||||
| 			suffix = "rawPredict" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if meta.BaseURL != "" { | ||||
| 		return fmt.Sprintf( | ||||
| 			"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", | ||||
| 			meta.BaseURL, | ||||
| 			meta.Config.VertexAIProjectID, | ||||
| 			meta.Config.Region, | ||||
| 			meta.ActualModelName, | ||||
| 			suffix, | ||||
| 		), nil | ||||
| 	} | ||||
| 	return fmt.Sprintf( | ||||
| 		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", | ||||
| 		meta.Config.Region, | ||||
| 		meta.Config.VertexAIProjectID, | ||||
| 		meta.Config.Region, | ||||
| 		meta.ActualModelName, | ||||
| 		suffix, | ||||
| 	), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	token, err := getToken(c, meta.ChannelId, meta.Config.VertexAIADC) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+token) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return channelhelper.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user