mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			127 Commits
		
	
	
		
			v0.6.8-alp
			...
			v0.6.11-al
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					c351e196e6 | ||
| 
						 | 
					a316ed7abc | ||
| 
						 | 
					0895d8660e | ||
| 
						 | 
					be1ed114f4 | ||
| 
						 | 
					eb6da573a3 | ||
| 
						 | 
					0a6273fc08 | ||
| 
						 | 
					5997fce454 | ||
| 
						 | 
					0df6d7a131 | ||
| 
						 | 
					93fdb60de5 | ||
| 
						 | 
					4db834da95 | ||
| 
						 | 
					6818ed5ca8 | ||
| 
						 | 
					7be3b5547d | ||
| 
						 | 
					2d7ea61d67 | ||
| 
						 | 
					83b34be067 | ||
| 
						 | 
					d5d879afdc | ||
| 
						 | 
					0f205a3aa3 | ||
| 
						 | 
					76c3f87351 | ||
| 
						 | 
					6d9a92f8f7 | ||
| 
						 | 
					835f0e0d67 | ||
| 
						 | 
					a6981f0d51 | ||
| 
						 | 
					678d613179 | ||
| 
						 | 
					be089a072b | ||
| 
						 | 
					45d10aa3df | ||
| 
						 | 
					9cdd48ac22 | ||
| 
						 | 
					310e7120e5 | ||
| 
						 | 
					3d29713268 | ||
| 
						 | 
					f2c7c424e9 | ||
| 
						 | 
					38a42bb265 | ||
| 
						 | 
					fa2e8f44b1 | ||
| 
						 | 
					9f74101543 | ||
| 
						 | 
					28a271a896 | ||
| 
						 | 
					e8ea87fff3 | ||
| 
						 | 
					abe2d2dba8 | ||
| 
						 | 
					4bcaa064d6 | ||
| 
						 | 
					52d81e0e24 | ||
| 
						 | 
					dc8c3bc69e | ||
| 
						 | 
					b4e69df802 | ||
| 
						 | 
					d9f74bdff3 | ||
| 
						 | 
					fa2a772731 | ||
| 
						 | 
					4f68f3e1b3 | ||
| 
						 | 
					0bab887b2d | ||
| 
						 | 
					0230d36643 | ||
| 
						 | 
					bad57d049a | ||
| 
						 | 
					dc470ce82e | ||
| 
						 | 
					ea0721d525 | ||
| 
						 | 
					d0402f9086 | ||
| 
						 | 
					1fead8e7f7 | ||
| 
						 | 
					09911a301d | ||
| 
						 | 
					f95e6b78b8 | ||
| 
						 | 
					605bb06667 | ||
| 
						 | 
					d88e07fd9a | ||
| 
						 | 
					3915ce9814 | ||
| 
						 | 
					999defc88b | ||
| 
						 | 
					b51c47bc77 | ||
| 
						 | 
					4f25cde132 | ||
| 
						 | 
					d89e9d7e44 | ||
| 
						 | 
					a858292b54 | ||
| 
						 | 
					ff589b5e4a | ||
| 
						 | 
					95e8c16338 | ||
| 
						 | 
					381172cb36 | ||
| 
						 | 
					59eae186a3 | ||
| 
						 | 
					ce52f355bb | ||
| 
						 | 
					cb9d0a74c9 | ||
| 
						 | 
					49ffb1c60d | ||
| 
						 | 
					2f16649896 | ||
| 
						 | 
					af3aa57bd6 | ||
| 
						 | 
					e9f117ff72 | ||
| 
						 | 
					6bb5247bd6 | ||
| 
						 | 
					305ce14fe3 | ||
| 
						 | 
					36c8f4f15c | ||
| 
						 | 
					45b51ea0ee | ||
| 
						 | 
					7c8628bd95 | ||
| 
						 | 
					6ab87f8a08 | ||
| 
						 | 
					833fa7ad6f | ||
| 
						 | 
					6eb0770a89 | ||
| 
						 | 
					92cd46d64f | ||
| 
						 | 
					2b2dc2c733 | ||
| 
						 | 
					a3d7df7f89 | ||
| 
						 | 
					c368232f50 | ||
| 
						 | 
					cbfc983dc3 | ||
| 
						 | 
					8ec092ba44 | ||
| 
						 | 
					b0b88a79ff | ||
| 
						 | 
					7e51b04221 | ||
| 
						 | 
					f75a17f8eb | ||
| 
						 | 
					6f13a3bb3c | ||
| 
						 | 
					f092eed1db | ||
| 
						 | 
					629378691b | ||
| 
						 | 
					3716e1b0e6 | ||
| 
						 | 
					a4d6e7a886 | ||
| 
						 | 
					cb772e5d06 | ||
| 
						 | 
					e32cb0b844 | ||
| 
						 | 
					fdd7bf41c0 | ||
| 
						 | 
					29389ed44f | ||
| 
						 | 
					88acc5a614 | ||
| 
						 | 
					a21681096a | ||
| 
						 | 
					32f90a79a8 | ||
| 
						 | 
					99c8c77504 | ||
| 
						 | 
					649ecbf29c | ||
| 
						 | 
					3a27c90910 | ||
| 
						 | 
					cba82404ae | ||
| 
						 | 
					c9ac670ba1 | ||
| 
						 | 
					15f815c23c | ||
| 
						 | 
					89b63ca96f | ||
| 
						 | 
					8cc54489b9 | ||
| 
						 | 
					58bf60805e | ||
| 
						 | 
					6714cf96d6 | ||
| 
						 | 
					f9774698e9 | ||
| 
						 | 
					2af6f6a166 | ||
| 
						 | 
					04bb3ef392 | ||
| 
						 | 
					b4bfa418a8 | ||
| 
						 | 
					e7e99e558a | ||
| 
						 | 
					402fcf7f79 | ||
| 
						 | 
					36039e329e | ||
| 
						 | 
					c936198ac8 | ||
| 
						 | 
					296ab013b8 | ||
| 
						 | 
					5f03c856b4 | ||
| 
						 | 
					39383e5532 | ||
| 
						 | 
					2a892c1937 | ||
| 
						 | 
					adba54acd3 | ||
| 
						 | 
					6209ff9ea9 | ||
| 
						 | 
					1c44d7e1cd | ||
| 
						 | 
					a3eefb7af0 | ||
| 
						 | 
					b65bee46fb | ||
| 
						 | 
					422a4e8ee5 | ||
| 
						 | 
					cf9b5f0b92 | ||
| 
						 | 
					65acb94f45 | ||
| 
						 | 
					6ad169975f | 
							
								
								
									
										10
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							@@ -1,19 +1,17 @@
 | 
			
		||||
name: CI
 | 
			
		||||
 | 
			
		||||
# This setup assumes that you run the unit tests with code coverage in the same
 | 
			
		||||
# workflow that will also print the coverage report as comment to the pull request. 
 | 
			
		||||
# workflow that will also print the coverage report as comment to the pull request.
 | 
			
		||||
# Therefore, you need to trigger this workflow when a pull request is (re)opened or
 | 
			
		||||
# when new code is pushed to the branch of the pull request. In addition, you also
 | 
			
		||||
# need to trigger this workflow when new code is pushed to the main branch because 
 | 
			
		||||
# need to trigger this workflow when new code is pushed to the main branch because
 | 
			
		||||
# we need to upload the code coverage results as artifact for the main branch as
 | 
			
		||||
# well since it will be the baseline code coverage.
 | 
			
		||||
# 
 | 
			
		||||
#
 | 
			
		||||
# We do not want to trigger the workflow for pushes to *any* branch because this
 | 
			
		||||
# would trigger our jobs twice on pull requests (once from "push" event and once
 | 
			
		||||
# from "pull_request->synchronize")
 | 
			
		||||
on:
 | 
			
		||||
  pull_request:
 | 
			
		||||
    types: [opened, reopened, synchronize]
 | 
			
		||||
  push:
 | 
			
		||||
    branches:
 | 
			
		||||
      - 'main'
 | 
			
		||||
@@ -31,7 +29,7 @@ jobs:
 | 
			
		||||
        with:
 | 
			
		||||
          go-version: ^1.22
 | 
			
		||||
 | 
			
		||||
      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a 
 | 
			
		||||
      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
 | 
			
		||||
      # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
 | 
			
		||||
      # in the next step as well as the next job.
 | 
			
		||||
      - name: Test
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										61
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										61
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							@@ -1,61 +0,0 @@
 | 
			
		||||
name: Publish Docker image (amd64)
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    tags:
 | 
			
		||||
      - 'v*.*.*'
 | 
			
		||||
  workflow_dispatch:
 | 
			
		||||
    inputs:
 | 
			
		||||
      name:
 | 
			
		||||
        description: 'reason'
 | 
			
		||||
        required: false
 | 
			
		||||
jobs:
 | 
			
		||||
  push_to_registries:
 | 
			
		||||
    name: Push Docker image to multiple registries
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    permissions:
 | 
			
		||||
      packages: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Check out the repo
 | 
			
		||||
        uses: actions/checkout@v3
 | 
			
		||||
 | 
			
		||||
      - name: Check repository URL
 | 
			
		||||
        run: |
 | 
			
		||||
          REPO_URL=$(git config --get remote.origin.url)
 | 
			
		||||
          if [[ $REPO_URL == *"pro" ]]; then
 | 
			
		||||
            exit 1
 | 
			
		||||
          fi        
 | 
			
		||||
 | 
			
		||||
      - name: Save version info
 | 
			
		||||
        run: |
 | 
			
		||||
          git describe --tags > VERSION 
 | 
			
		||||
 | 
			
		||||
      - name: Log in to Docker Hub
 | 
			
		||||
        uses: docker/login-action@v2
 | 
			
		||||
        with:
 | 
			
		||||
          username: ${{ secrets.DOCKERHUB_USERNAME }}
 | 
			
		||||
          password: ${{ secrets.DOCKERHUB_TOKEN }}
 | 
			
		||||
 | 
			
		||||
      - name: Log in to the Container registry
 | 
			
		||||
        uses: docker/login-action@v2
 | 
			
		||||
        with:
 | 
			
		||||
          registry: ghcr.io
 | 
			
		||||
          username: ${{ github.actor }}
 | 
			
		||||
          password: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
 | 
			
		||||
      - name: Extract metadata (tags, labels) for Docker
 | 
			
		||||
        id: meta
 | 
			
		||||
        uses: docker/metadata-action@v4
 | 
			
		||||
        with:
 | 
			
		||||
          images: |
 | 
			
		||||
            justsong/one-api
 | 
			
		||||
            ghcr.io/${{ github.repository }}
 | 
			
		||||
 | 
			
		||||
      - name: Build and push Docker images
 | 
			
		||||
        uses: docker/build-push-action@v3
 | 
			
		||||
        with:
 | 
			
		||||
          context: .
 | 
			
		||||
          push: true
 | 
			
		||||
          tags: ${{ steps.meta.outputs.tags }}
 | 
			
		||||
          labels: ${{ steps.meta.outputs.labels }}
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
name: Publish Docker image (amd64, English)
 | 
			
		||||
name: Publish Docker image (English)
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
@@ -34,6 +34,13 @@ jobs:
 | 
			
		||||
      - name: Translate
 | 
			
		||||
        run: |
 | 
			
		||||
          python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
 | 
			
		||||
 | 
			
		||||
      - name: Set up QEMU
 | 
			
		||||
        uses: docker/setup-qemu-action@v2
 | 
			
		||||
 | 
			
		||||
      - name: Set up Docker Buildx
 | 
			
		||||
        uses: docker/setup-buildx-action@v2
 | 
			
		||||
 | 
			
		||||
      - name: Log in to Docker Hub
 | 
			
		||||
        uses: docker/login-action@v2
 | 
			
		||||
        with:
 | 
			
		||||
@@ -51,6 +58,7 @@ jobs:
 | 
			
		||||
        uses: docker/build-push-action@v3
 | 
			
		||||
        with:
 | 
			
		||||
          context: .
 | 
			
		||||
          platforms: linux/amd64,linux/arm64
 | 
			
		||||
          push: true
 | 
			
		||||
          tags: ${{ steps.meta.outputs.tags }}
 | 
			
		||||
          labels: ${{ steps.meta.outputs.labels }}
 | 
			
		||||
@@ -1,10 +1,9 @@
 | 
			
		||||
name: Publish Docker image (arm64)
 | 
			
		||||
name: Publish Docker image
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    tags:
 | 
			
		||||
      - 'v*.*.*'
 | 
			
		||||
      - '!*-alpha*'
 | 
			
		||||
  workflow_dispatch:
 | 
			
		||||
    inputs:
 | 
			
		||||
      name:
 | 
			
		||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -9,4 +9,5 @@ logs
 | 
			
		||||
data
 | 
			
		||||
/web/node_modules
 | 
			
		||||
cmd.md
 | 
			
		||||
.env
 | 
			
		||||
.env
 | 
			
		||||
/one-api
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										21
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								Dockerfile
									
									
									
									
									
								
							@@ -1,24 +1,23 @@
 | 
			
		||||
FROM node:16 as builder
 | 
			
		||||
FROM --platform=$BUILDPLATFORM node:16 AS builder
 | 
			
		||||
 | 
			
		||||
WORKDIR /web
 | 
			
		||||
COPY ./VERSION .
 | 
			
		||||
COPY ./web .
 | 
			
		||||
 | 
			
		||||
WORKDIR /web/default
 | 
			
		||||
RUN npm install
 | 
			
		||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
 | 
			
		||||
RUN npm install --prefix /web/default & \
 | 
			
		||||
    npm install --prefix /web/berry & \
 | 
			
		||||
    npm install --prefix /web/air & \
 | 
			
		||||
    wait
 | 
			
		||||
 | 
			
		||||
WORKDIR /web/berry
 | 
			
		||||
RUN npm install
 | 
			
		||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
 | 
			
		||||
 | 
			
		||||
WORKDIR /web/air
 | 
			
		||||
RUN npm install
 | 
			
		||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
 | 
			
		||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/default/VERSION) npm run build --prefix /web/default & \
 | 
			
		||||
    DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/berry/VERSION) npm run build --prefix /web/berry & \
 | 
			
		||||
    DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat /web/air/VERSION) npm run build --prefix /web/air & \
 | 
			
		||||
    wait
 | 
			
		||||
 | 
			
		||||
FROM golang:alpine AS builder2
 | 
			
		||||
 | 
			
		||||
RUN apk add --no-cache g++
 | 
			
		||||
RUN apk add --no-cache gcc musl-dev libc-dev sqlite-dev build-base
 | 
			
		||||
 | 
			
		||||
ENV GO111MODULE=on \
 | 
			
		||||
    CGO_ENABLED=1 \
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										23
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								README.md
									
									
									
									
									
								
							@@ -89,6 +89,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
   + [x] [DeepL](https://www.deepl.com/)
 | 
			
		||||
   + [x] [together.ai](https://www.together.ai/)
 | 
			
		||||
   + [x] [novita.ai](https://www.novita.ai/)
 | 
			
		||||
   + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
 | 
			
		||||
   + [x] [xAI](https://x.ai/)
 | 
			
		||||
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
 | 
			
		||||
3. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
@@ -113,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
21. 支持 Cloudflare Turnstile 用户校验。
 | 
			
		||||
22. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
 | 
			
		||||
    + 支持使用飞书进行授权登录。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。
 | 
			
		||||
    + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
 | 
			
		||||
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
 | 
			
		||||
@@ -173,6 +175,10 @@ sudo service nginx restart
 | 
			
		||||
 | 
			
		||||
初始账号用户名为 `root`,密码为 `123456`。
 | 
			
		||||
 | 
			
		||||
### 通过宝塔面板进行一键部署
 | 
			
		||||
1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装;
 | 
			
		||||
2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装;
 | 
			
		||||
3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装;
 | 
			
		||||
 | 
			
		||||
### 基于 Docker Compose 进行部署
 | 
			
		||||
 | 
			
		||||
@@ -216,7 +222,7 @@ docker-compose ps
 | 
			
		||||
3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。
 | 
			
		||||
4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。
 | 
			
		||||
5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
 | 
			
		||||
6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。
 | 
			
		||||
6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟(Redis 集群或者哨兵模式的支持请参考环境变量说明)。
 | 
			
		||||
7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
 | 
			
		||||
 | 
			
		||||
环境变量的具体使用方法详见[此处](#环境变量)。
 | 
			
		||||
@@ -251,9 +257,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
			
		||||
#### QChatGPT - QQ机器人
 | 
			
		||||
项目主页:https://github.com/RockChinQ/QChatGPT
 | 
			
		||||
 | 
			
		||||
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
 | 
			
		||||
根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。
 | 
			
		||||
 | 
			
		||||
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
 | 
			
		||||
运行期间可以通过`!model`命令查看、切换可用模型。
 | 
			
		||||
 | 
			
		||||
### 部署到第三方平台
 | 
			
		||||
<details>
 | 
			
		||||
@@ -345,6 +351,11 @@ graph LR
 | 
			
		||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
 | 
			
		||||
   + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
			
		||||
   + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
 | 
			
		||||
   + 如果需要使用哨兵或者集群模式:
 | 
			
		||||
     + 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。
 | 
			
		||||
     + 除此之外还需要设置以下环境变量:
 | 
			
		||||
       + `REDIS_PASSWORD`:Redis 集群或者哨兵模式下的密码设置。
 | 
			
		||||
       + `REDIS_MASTER_NAME`:Redis 哨兵模式下主节点的名称。
 | 
			
		||||
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
 | 
			
		||||
   + 例子:`SESSION_SECRET=random_string`
 | 
			
		||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
 | 
			
		||||
@@ -398,6 +409,8 @@ graph LR
 | 
			
		||||
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
 | 
			
		||||
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
 | 
			
		||||
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
 | 
			
		||||
29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。
 | 
			
		||||
30. `TEST_PROMPT`:测试模型时的用户 prompt,默认为 `Print your model name exactly and do not output without any other text.`。
 | 
			
		||||
 | 
			
		||||
### 命令行参数
 | 
			
		||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
			
		||||
 
 | 
			
		||||
@@ -1,13 +1,14 @@
 | 
			
		||||
package config
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/env"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/env"
 | 
			
		||||
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -35,6 +36,7 @@ var PasswordLoginEnabled = true
 | 
			
		||||
var PasswordRegisterEnabled = true
 | 
			
		||||
var EmailVerificationEnabled = false
 | 
			
		||||
var GitHubOAuthEnabled = false
 | 
			
		||||
var OidcEnabled = false
 | 
			
		||||
var WeChatAuthEnabled = false
 | 
			
		||||
var TurnstileCheckEnabled = false
 | 
			
		||||
var RegisterEnabled = true
 | 
			
		||||
@@ -70,6 +72,13 @@ var GitHubClientSecret = ""
 | 
			
		||||
var LarkClientId = ""
 | 
			
		||||
var LarkClientSecret = ""
 | 
			
		||||
 | 
			
		||||
var OidcClientId = ""
 | 
			
		||||
var OidcClientSecret = ""
 | 
			
		||||
var OidcWellKnown = ""
 | 
			
		||||
var OidcAuthorizationEndpoint = ""
 | 
			
		||||
var OidcTokenEndpoint = ""
 | 
			
		||||
var OidcUserinfoEndpoint = ""
 | 
			
		||||
 | 
			
		||||
var WeChatServerAddress = ""
 | 
			
		||||
var WeChatServerToken = ""
 | 
			
		||||
var WeChatAccountQRCodeImageURL = ""
 | 
			
		||||
@@ -147,9 +156,11 @@ var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
 | 
			
		||||
 | 
			
		||||
var GeminiVersion = env.String("GEMINI_VERSION", "v1")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
 | 
			
		||||
 | 
			
		||||
var RelayProxy = env.String("RELAY_PROXY", "")
 | 
			
		||||
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
 | 
			
		||||
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
 | 
			
		||||
 | 
			
		||||
var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
 | 
			
		||||
var TestPrompt = env.String("TEST_PROMPT", "Print your model name exactly and do not output without any other text.")
 | 
			
		||||
 
 | 
			
		||||
@@ -20,4 +20,5 @@ const (
 | 
			
		||||
	BaseURL           = "base_url"
 | 
			
		||||
	AvailableModels   = "available_models"
 | 
			
		||||
	KeyRequestBody    = "key_request_body"
 | 
			
		||||
	SystemPrompt      = "system_prompt"
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 | 
			
		||||
	contentType := c.Request.Header.Get("Content-Type")
 | 
			
		||||
	if strings.HasPrefix(contentType, "application/json") {
 | 
			
		||||
		err = json.Unmarshal(requestBody, &v)
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
	} else {
 | 
			
		||||
		// skip for now
 | 
			
		||||
		// TODO: someday non json request have variant model, we will need to implementation this
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
		err = c.ShouldBind(&v)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	// Reset request body
 | 
			
		||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,8 @@
 | 
			
		||||
package helper
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
@@ -11,6 +10,10 @@ import (
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func OpenBrowser(url string) {
 | 
			
		||||
@@ -106,6 +109,18 @@ func GenRequestID() string {
 | 
			
		||||
	return GetTimeString() + random.GetRandomNumberString(8)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SetRequestID(ctx context.Context, id string) context.Context {
 | 
			
		||||
	return context.WithValue(ctx, RequestIdKey, id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRequestID(ctx context.Context) string {
 | 
			
		||||
	rawRequestId := ctx.Value(RequestIdKey)
 | 
			
		||||
	if rawRequestId == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return rawRequestId.(string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetResponseID(c *gin.Context) string {
 | 
			
		||||
	logID := c.GetString(RequestIdKey)
 | 
			
		||||
	return fmt.Sprintf("chatcmpl-%s", logID)
 | 
			
		||||
@@ -137,3 +152,23 @@ func String2Int(str string) int {
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Float64PtrMax(p *float64, maxValue float64) *float64 {
 | 
			
		||||
	if p == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if *p > maxValue {
 | 
			
		||||
		return &maxValue
 | 
			
		||||
	}
 | 
			
		||||
	return p
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Float64PtrMin(p *float64, minValue float64) *float64 {
 | 
			
		||||
	if p == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if *p < minValue {
 | 
			
		||||
		return &minValue
 | 
			
		||||
	}
 | 
			
		||||
	return p
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -13,3 +13,8 @@ func GetTimeString() string {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CalcElapsedTime return the elapsed time in milliseconds (ms)
 | 
			
		||||
func CalcElapsedTime(start time.Time) int64 {
 | 
			
		||||
	return time.Now().Sub(start).Milliseconds()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,19 +7,25 @@ import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type loggerLevel string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	loggerDEBUG = "DEBUG"
 | 
			
		||||
	loggerINFO  = "INFO"
 | 
			
		||||
	loggerWarn  = "WARN"
 | 
			
		||||
	loggerError = "ERR"
 | 
			
		||||
	loggerDEBUG loggerLevel = "DEBUG"
 | 
			
		||||
	loggerINFO  loggerLevel = "INFO"
 | 
			
		||||
	loggerWarn  loggerLevel = "WARN"
 | 
			
		||||
	loggerError loggerLevel = "ERROR"
 | 
			
		||||
	loggerFatal loggerLevel = "FATAL"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var setupLogOnce sync.Once
 | 
			
		||||
@@ -44,27 +50,26 @@ func SetupLogger() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SysLog(s string) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 | 
			
		||||
	logHelper(nil, loggerINFO, s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SysLogf(format string, a ...any) {
 | 
			
		||||
	SysLog(fmt.Sprintf(format, a...))
 | 
			
		||||
	logHelper(nil, loggerINFO, fmt.Sprintf(format, a...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SysError(s string) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 | 
			
		||||
	logHelper(nil, loggerError, s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SysErrorf(format string, a ...any) {
 | 
			
		||||
	SysError(fmt.Sprintf(format, a...))
 | 
			
		||||
	logHelper(nil, loggerError, fmt.Sprintf(format, a...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Debug(ctx context.Context, msg string) {
 | 
			
		||||
	if config.DebugEnabled {
 | 
			
		||||
		logHelper(ctx, loggerDEBUG, msg)
 | 
			
		||||
	if !config.DebugEnabled {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logHelper(ctx, loggerDEBUG, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Info(ctx context.Context, msg string) {
 | 
			
		||||
@@ -80,37 +85,65 @@ func Error(ctx context.Context, msg string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Debugf(ctx context.Context, format string, a ...any) {
 | 
			
		||||
	Debug(ctx, fmt.Sprintf(format, a...))
 | 
			
		||||
	logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Infof(ctx context.Context, format string, a ...any) {
 | 
			
		||||
	Info(ctx, fmt.Sprintf(format, a...))
 | 
			
		||||
	logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Warnf(ctx context.Context, format string, a ...any) {
 | 
			
		||||
	Warn(ctx, fmt.Sprintf(format, a...))
 | 
			
		||||
	logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Errorf(ctx context.Context, format string, a ...any) {
 | 
			
		||||
	Error(ctx, fmt.Sprintf(format, a...))
 | 
			
		||||
	logHelper(ctx, loggerError, fmt.Sprintf(format, a...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func logHelper(ctx context.Context, level string, msg string) {
 | 
			
		||||
func FatalLog(s string) {
 | 
			
		||||
	logHelper(nil, loggerFatal, s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FatalLogf(format string, a ...any) {
 | 
			
		||||
	logHelper(nil, loggerFatal, fmt.Sprintf(format, a...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func logHelper(ctx context.Context, level loggerLevel, msg string) {
 | 
			
		||||
	writer := gin.DefaultErrorWriter
 | 
			
		||||
	if level == loggerINFO {
 | 
			
		||||
		writer = gin.DefaultWriter
 | 
			
		||||
	}
 | 
			
		||||
	id := ctx.Value(helper.RequestIdKey)
 | 
			
		||||
	if id == nil {
 | 
			
		||||
		id = helper.GenRequestID()
 | 
			
		||||
	var requestId string
 | 
			
		||||
	if ctx != nil {
 | 
			
		||||
		rawRequestId := helper.GetRequestID(ctx)
 | 
			
		||||
		if rawRequestId != "" {
 | 
			
		||||
			requestId = fmt.Sprintf(" | %s", rawRequestId)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	lineInfo, funcName := getLineInfo()
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
 | 
			
		||||
	_, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg)
 | 
			
		||||
	SetupLogger()
 | 
			
		||||
	if level == loggerFatal {
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FatalLog(v ...any) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 | 
			
		||||
	os.Exit(1)
 | 
			
		||||
func getLineInfo() (string, string) {
 | 
			
		||||
	funcName := "[unknown] "
 | 
			
		||||
	pc, file, line, ok := runtime.Caller(3)
 | 
			
		||||
	if ok {
 | 
			
		||||
		if fn := runtime.FuncForPC(pc); fn != nil {
 | 
			
		||||
			parts := strings.Split(fn.Name(), ".")
 | 
			
		||||
			funcName = "[" + parts[len(parts)-1] + "] "
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		file = "unknown"
 | 
			
		||||
		line = 0
 | 
			
		||||
	}
 | 
			
		||||
	parts := strings.Split(file, "one-api/")
 | 
			
		||||
	if len(parts) > 1 {
 | 
			
		||||
		file = parts[1]
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf(" | %s:%d", file, line), funcName
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,15 @@ package common
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"os"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var RDB *redis.Client
 | 
			
		||||
var RDB redis.Cmdable
 | 
			
		||||
var RedisEnabled = true
 | 
			
		||||
 | 
			
		||||
// InitRedisClient This function is called after init()
 | 
			
		||||
@@ -23,13 +25,23 @@ func InitRedisClient() (err error) {
 | 
			
		||||
		logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	logger.SysLog("Redis is enabled")
 | 
			
		||||
	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to parse Redis connection string: " + err.Error())
 | 
			
		||||
	redisConnString := os.Getenv("REDIS_CONN_STRING")
 | 
			
		||||
	if os.Getenv("REDIS_MASTER_NAME") == "" {
 | 
			
		||||
		logger.SysLog("Redis is enabled")
 | 
			
		||||
		opt, err := redis.ParseURL(redisConnString)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.FatalLog("failed to parse Redis connection string: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		RDB = redis.NewClient(opt)
 | 
			
		||||
	} else {
 | 
			
		||||
		// cluster mode
 | 
			
		||||
		logger.SysLog("Redis cluster mode enabled")
 | 
			
		||||
		RDB = redis.NewUniversalClient(&redis.UniversalOptions{
 | 
			
		||||
			Addrs:      strings.Split(redisConnString, ","),
 | 
			
		||||
			Password:   os.Getenv("REDIS_PASSWORD"),
 | 
			
		||||
			MasterName: os.Getenv("REDIS_MASTER_NAME"),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	RDB = redis.NewClient(opt)
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,9 +3,10 @@ package render
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func StringData(c *gin.Context, str string) {
 | 
			
		||||
 
 | 
			
		||||
@@ -5,16 +5,18 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type GitHubOAuthResponse struct {
 | 
			
		||||
@@ -81,6 +83,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GitHubOAuth(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	state := c.Query("state")
 | 
			
		||||
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
 | 
			
		||||
@@ -136,7 +139,7 @@ func GitHubOAuth(c *gin.Context) {
 | 
			
		||||
			user.Role = model.RoleCommonUser
 | 
			
		||||
			user.Status = model.UserStatusEnabled
 | 
			
		||||
 | 
			
		||||
			if err := user.Insert(0); err != nil {
 | 
			
		||||
			if err := user.Insert(ctx, 0); err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
					"success": false,
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
 
 | 
			
		||||
@@ -5,15 +5,17 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type LarkOAuthResponse struct {
 | 
			
		||||
@@ -40,7 +42,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
 | 
			
		||||
	req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -79,6 +81,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LarkOAuth(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	state := c.Query("state")
 | 
			
		||||
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
 | 
			
		||||
@@ -125,7 +128,7 @@ func LarkOAuth(c *gin.Context) {
 | 
			
		||||
			user.Role = model.RoleCommonUser
 | 
			
		||||
			user.Status = model.UserStatusEnabled
 | 
			
		||||
 | 
			
		||||
			if err := user.Insert(0); err != nil {
 | 
			
		||||
			if err := user.Insert(ctx, 0); err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
					"success": false,
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										228
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								controller/auth/oidc.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,228 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type OidcResponse struct {
 | 
			
		||||
	AccessToken  string `json:"access_token"`
 | 
			
		||||
	IDToken      string `json:"id_token"`
 | 
			
		||||
	RefreshToken string `json:"refresh_token"`
 | 
			
		||||
	TokenType    string `json:"token_type"`
 | 
			
		||||
	ExpiresIn    int    `json:"expires_in"`
 | 
			
		||||
	Scope        string `json:"scope"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OidcUser struct {
 | 
			
		||||
	OpenID            string `json:"sub"`
 | 
			
		||||
	Email             string `json:"email"`
 | 
			
		||||
	Name              string `json:"name"`
 | 
			
		||||
	PreferredUsername string `json:"preferred_username"`
 | 
			
		||||
	Picture           string `json:"picture"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
 | 
			
		||||
	if code == "" {
 | 
			
		||||
		return nil, errors.New("无效的参数")
 | 
			
		||||
	}
 | 
			
		||||
	values := map[string]string{
 | 
			
		||||
		"client_id":     config.OidcClientId,
 | 
			
		||||
		"client_secret": config.OidcClientSecret,
 | 
			
		||||
		"code":          code,
 | 
			
		||||
		"grant_type":    "authorization_code",
 | 
			
		||||
		"redirect_uri":  fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
 | 
			
		||||
	}
 | 
			
		||||
	jsonData, err := json.Marshal(values)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	req.Header.Set("Accept", "application/json")
 | 
			
		||||
	client := http.Client{
 | 
			
		||||
		Timeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	res, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysLog(err.Error())
 | 
			
		||||
		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
	var oidcResponse OidcResponse
 | 
			
		||||
	err = json.NewDecoder(res.Body).Decode(&oidcResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
 | 
			
		||||
	res2, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysLog(err.Error())
 | 
			
		||||
		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
 | 
			
		||||
	}
 | 
			
		||||
	var oidcUser OidcUser
 | 
			
		||||
	err = json.NewDecoder(res2.Body).Decode(&oidcUser)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &oidcUser, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OidcAuth(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	state := c.Query("state")
 | 
			
		||||
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
 | 
			
		||||
		c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "state is empty or not same",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	username := session.Get("username")
 | 
			
		||||
	if username != nil {
 | 
			
		||||
		OidcBind(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !config.OidcEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "管理员未开启通过 OIDC 登录以及注册",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	code := c.Query("code")
 | 
			
		||||
	oidcUser, err := getOidcUserInfoByCode(code)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user := model.User{
 | 
			
		||||
		OidcId: oidcUser.OpenID,
 | 
			
		||||
	}
 | 
			
		||||
	if model.IsOidcIdAlreadyTaken(user.OidcId) {
 | 
			
		||||
		err := user.FillUserByOidcId()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": err.Error(),
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if config.RegisterEnabled {
 | 
			
		||||
			user.Email = oidcUser.Email
 | 
			
		||||
			if oidcUser.PreferredUsername != "" {
 | 
			
		||||
				user.Username = oidcUser.PreferredUsername
 | 
			
		||||
			} else {
 | 
			
		||||
				user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
 | 
			
		||||
			}
 | 
			
		||||
			if oidcUser.Name != "" {
 | 
			
		||||
				user.DisplayName = oidcUser.Name
 | 
			
		||||
			} else {
 | 
			
		||||
				user.DisplayName = "OIDC User"
 | 
			
		||||
			}
 | 
			
		||||
			err := user.Insert(ctx, 0)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
					"success": false,
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "管理员关闭了新用户注册",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Status != model.UserStatusEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"message": "用户已被封禁",
 | 
			
		||||
			"success": false,
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	controller.SetupLogin(&user, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OidcBind(c *gin.Context) {
 | 
			
		||||
	if !config.OidcEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "管理员未开启通过 OIDC 登录以及注册",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	code := c.Query("code")
 | 
			
		||||
	oidcUser, err := getOidcUserInfoByCode(code)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user := model.User{
 | 
			
		||||
		OidcId: oidcUser.OpenID,
 | 
			
		||||
	}
 | 
			
		||||
	if model.IsOidcIdAlreadyTaken(user.OidcId) {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "该 OIDC 账户已被绑定",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	id := session.Get("id")
 | 
			
		||||
	// id := c.GetInt("id")  // critical bug!
 | 
			
		||||
	user.Id = id.(int)
 | 
			
		||||
	err = user.FillUserById()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user.OidcId = oidcUser.OpenID
 | 
			
		||||
	err = user.Update(false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "bind",
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -4,14 +4,16 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type wechatLoginResponse struct {
 | 
			
		||||
@@ -52,6 +54,7 @@ func getWeChatIdByCode(code string) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WeChatAuth(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	if !config.WeChatAuthEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"message": "管理员未开启通过微信登录以及注册",
 | 
			
		||||
@@ -87,7 +90,7 @@ func WeChatAuth(c *gin.Context) {
 | 
			
		||||
			user.Role = model.RoleCommonUser
 | 
			
		||||
			user.Status = model.UserStatusEnabled
 | 
			
		||||
 | 
			
		||||
			if err := user.Insert(0); err != nil {
 | 
			
		||||
			if err := user.Insert(ctx, 0); err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
					"success": false,
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
 
 | 
			
		||||
@@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) {
 | 
			
		||||
	if config.DisplayTokenStatEnabled {
 | 
			
		||||
		tokenId := c.GetInt(ctxkey.TokenId)
 | 
			
		||||
		token, err = model.GetTokenById(tokenId)
 | 
			
		||||
		expiredTime = token.ExpiredTime
 | 
			
		||||
		remainQuota = token.RemainQuota
 | 
			
		||||
		usedQuota = token.UsedQuota
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			expiredTime = token.ExpiredTime
 | 
			
		||||
			remainQuota = token.RemainQuota
 | 
			
		||||
			usedQuota = token.UsedQuota
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
		remainQuota, err = model.GetUserQuota(userId)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,16 +4,17 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/client"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/monitor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
@@ -81,6 +82,36 @@ type APGC2DGPTUsageResponse struct {
 | 
			
		||||
	TotalUsed      float64 `json:"total_used"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SiliconFlowUsageResponse struct {
 | 
			
		||||
	Code    int    `json:"code"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
	Status  bool   `json:"status"`
 | 
			
		||||
	Data    struct {
 | 
			
		||||
		ID            string `json:"id"`
 | 
			
		||||
		Name          string `json:"name"`
 | 
			
		||||
		Image         string `json:"image"`
 | 
			
		||||
		Email         string `json:"email"`
 | 
			
		||||
		IsAdmin       bool   `json:"isAdmin"`
 | 
			
		||||
		Balance       string `json:"balance"`
 | 
			
		||||
		Status        string `json:"status"`
 | 
			
		||||
		Introduction  string `json:"introduction"`
 | 
			
		||||
		Role          string `json:"role"`
 | 
			
		||||
		ChargeBalance string `json:"chargeBalance"`
 | 
			
		||||
		TotalBalance  string `json:"totalBalance"`
 | 
			
		||||
		Category      string `json:"category"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DeepSeekUsageResponse struct {
 | 
			
		||||
	IsAvailable  bool `json:"is_available"`
 | 
			
		||||
	BalanceInfos []struct {
 | 
			
		||||
		Currency        string `json:"currency"`
 | 
			
		||||
		TotalBalance    string `json:"total_balance"`
 | 
			
		||||
		GrantedBalance  string `json:"granted_balance"`
 | 
			
		||||
		ToppedUpBalance string `json:"topped_up_balance"`
 | 
			
		||||
	} `json:"balance_infos"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAuthHeader get auth header
 | 
			
		||||
func GetAuthHeader(token string) http.Header {
 | 
			
		||||
	h := http.Header{}
 | 
			
		||||
@@ -203,6 +234,57 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	return response.TotalAvailable, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	url := "https://api.siliconflow.cn/v1/user/info"
 | 
			
		||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	response := SiliconFlowUsageResponse{}
 | 
			
		||||
	err = json.Unmarshal(body, &response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	if response.Code != 20000 {
 | 
			
		||||
		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
 | 
			
		||||
	}
 | 
			
		||||
	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	channel.UpdateBalance(balance)
 | 
			
		||||
	return balance, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	url := "https://api.deepseek.com/user/balance"
 | 
			
		||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	response := DeepSeekUsageResponse{}
 | 
			
		||||
	err = json.Unmarshal(body, &response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	index := -1
 | 
			
		||||
	for i, balanceInfo := range response.BalanceInfos {
 | 
			
		||||
		if balanceInfo.Currency == "CNY" {
 | 
			
		||||
			index = i
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if index == -1 {
 | 
			
		||||
		return 0, errors.New("currency CNY not found")
 | 
			
		||||
	}
 | 
			
		||||
	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	channel.UpdateBalance(balance)
 | 
			
		||||
	return balance, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := channeltype.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.GetBaseURL() == "" {
 | 
			
		||||
@@ -227,6 +309,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
		return updateChannelAPI2GPTBalance(channel)
 | 
			
		||||
	case channeltype.AIGC2D:
 | 
			
		||||
		return updateChannelAIGC2DBalance(channel)
 | 
			
		||||
	case channeltype.SiliconFlow:
 | 
			
		||||
		return updateChannelSiliconFlowBalance(channel)
 | 
			
		||||
	case channeltype.DeepSeek:
 | 
			
		||||
		return updateChannelDeepSeekBalance(channel)
 | 
			
		||||
	default:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -15,14 +16,17 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/message"
 | 
			
		||||
	"github.com/songquanpeng/one-api/middleware"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/monitor"
 | 
			
		||||
	relay "github.com/songquanpeng/one-api/relay"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
@@ -35,18 +39,34 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
 | 
			
		||||
		model = "gpt-3.5-turbo"
 | 
			
		||||
	}
 | 
			
		||||
	testRequest := &relaymodel.GeneralOpenAIRequest{
 | 
			
		||||
		MaxTokens: 2,
 | 
			
		||||
		Model:     model,
 | 
			
		||||
		Model: model,
 | 
			
		||||
	}
 | 
			
		||||
	testMessage := relaymodel.Message{
 | 
			
		||||
		Role:    "user",
 | 
			
		||||
		Content: "hi",
 | 
			
		||||
		Content: config.TestPrompt,
 | 
			
		||||
	}
 | 
			
		||||
	testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
			
		||||
	return testRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
 | 
			
		||||
func parseTestResponse(resp string) (*openai.TextResponse, string, error) {
 | 
			
		||||
	var response openai.TextResponse
 | 
			
		||||
	err := json.Unmarshal([]byte(resp), &response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, "", err
 | 
			
		||||
	}
 | 
			
		||||
	if len(response.Choices) == 0 {
 | 
			
		||||
		return nil, "", errors.New("response has no choices")
 | 
			
		||||
	}
 | 
			
		||||
	stringContent, ok := response.Choices[0].Content.(string)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, "", errors.New("response content is not string")
 | 
			
		||||
	}
 | 
			
		||||
	return &response, stringContent, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) {
 | 
			
		||||
	startTime := time.Now()
 | 
			
		||||
	w := httptest.NewRecorder()
 | 
			
		||||
	c, _ := gin.CreateTestContext(w)
 | 
			
		||||
	c.Request = &http.Request{
 | 
			
		||||
@@ -66,7 +86,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
 | 
			
		||||
	apiType := channeltype.ToAPIType(channel.Type)
 | 
			
		||||
	adaptor := relay.GetAdaptor(apiType)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 | 
			
		||||
		return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 | 
			
		||||
	}
 | 
			
		||||
	adaptor.Init(meta)
 | 
			
		||||
	modelName := request.Model
 | 
			
		||||
@@ -76,49 +96,77 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
 | 
			
		||||
		if len(modelNames) > 0 {
 | 
			
		||||
			modelName = modelNames[0]
 | 
			
		||||
		}
 | 
			
		||||
		if modelMap != nil && modelMap[modelName] != "" {
 | 
			
		||||
			modelName = modelMap[modelName]
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if modelMap != nil && modelMap[modelName] != "" {
 | 
			
		||||
		modelName = modelMap[modelName]
 | 
			
		||||
	}
 | 
			
		||||
	meta.OriginModelName, meta.ActualModelName = request.Model, modelName
 | 
			
		||||
	request.Model = modelName
 | 
			
		||||
	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
		return "", err, nil
 | 
			
		||||
	}
 | 
			
		||||
	jsonData, err := json.Marshal(convertedRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
		return "", err, nil
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage)
 | 
			
		||||
		if err != nil || openaiErr != nil {
 | 
			
		||||
			errorMessage := ""
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				errorMessage = err.Error()
 | 
			
		||||
			} else {
 | 
			
		||||
				errorMessage = openaiErr.Message
 | 
			
		||||
			}
 | 
			
		||||
			logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage)
 | 
			
		||||
		}
 | 
			
		||||
		go model.RecordTestLog(ctx, &model.Log{
 | 
			
		||||
			ChannelId:   channel.Id,
 | 
			
		||||
			ModelName:   modelName,
 | 
			
		||||
			Content:     logContent,
 | 
			
		||||
			ElapsedTime: helper.CalcElapsedTime(startTime),
 | 
			
		||||
		})
 | 
			
		||||
	}()
 | 
			
		||||
	logger.SysLog(string(jsonData))
 | 
			
		||||
	requestBody := bytes.NewBuffer(jsonData)
 | 
			
		||||
	c.Request.Body = io.NopCloser(requestBody)
 | 
			
		||||
	resp, err := adaptor.DoRequest(c, meta, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
		return "", err, nil
 | 
			
		||||
	}
 | 
			
		||||
	if resp != nil && resp.StatusCode != http.StatusOK {
 | 
			
		||||
		err := controller.RelayErrorHandler(resp)
 | 
			
		||||
		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
 | 
			
		||||
		errorMessage := err.Error.Message
 | 
			
		||||
		if errorMessage != "" {
 | 
			
		||||
			errorMessage = ", error message: " + errorMessage
 | 
			
		||||
		}
 | 
			
		||||
		return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error
 | 
			
		||||
	}
 | 
			
		||||
	usage, respErr := adaptor.DoResponse(c, resp, meta)
 | 
			
		||||
	if respErr != nil {
 | 
			
		||||
		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
 | 
			
		||||
		return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
 | 
			
		||||
	}
 | 
			
		||||
	if usage == nil {
 | 
			
		||||
		return errors.New("usage is nil"), nil
 | 
			
		||||
		return "", errors.New("usage is nil"), nil
 | 
			
		||||
	}
 | 
			
		||||
	rawResponse := w.Body.String()
 | 
			
		||||
	_, responseMessage, err = parseTestResponse(rawResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err, nil
 | 
			
		||||
	}
 | 
			
		||||
	result := w.Result()
 | 
			
		||||
	// print result.Body
 | 
			
		||||
	respBody, err := io.ReadAll(result.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
		return "", err, nil
 | 
			
		||||
	}
 | 
			
		||||
	logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 | 
			
		||||
	return nil, nil
 | 
			
		||||
	return responseMessage, nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChannel(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	id, err := strconv.Atoi(c.Param("id"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -135,10 +183,10 @@ func TestChannel(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	model := c.Query("model")
 | 
			
		||||
	testRequest := buildTestRequest(model)
 | 
			
		||||
	modelName := c.Query("model")
 | 
			
		||||
	testRequest := buildTestRequest(modelName)
 | 
			
		||||
	tik := time.Now()
 | 
			
		||||
	err, _ = testChannel(channel, testRequest)
 | 
			
		||||
	responseMessage, err, _ := testChannel(ctx, channel, testRequest)
 | 
			
		||||
	tok := time.Now()
 | 
			
		||||
	milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -148,18 +196,18 @@ func TestChannel(c *gin.Context) {
 | 
			
		||||
	consumedTime := float64(milliseconds) / 1000.0
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
			"time":    consumedTime,
 | 
			
		||||
			"model":   model,
 | 
			
		||||
			"success":   false,
 | 
			
		||||
			"message":   err.Error(),
 | 
			
		||||
			"time":      consumedTime,
 | 
			
		||||
			"modelName": modelName,
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"time":    consumedTime,
 | 
			
		||||
		"model":   model,
 | 
			
		||||
		"success":   true,
 | 
			
		||||
		"message":   responseMessage,
 | 
			
		||||
		"time":      consumedTime,
 | 
			
		||||
		"modelName": modelName,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -167,7 +215,7 @@ func TestChannel(c *gin.Context) {
 | 
			
		||||
var testAllChannelsLock sync.Mutex
 | 
			
		||||
var testAllChannelsRunning bool = false
 | 
			
		||||
 | 
			
		||||
func testChannels(notify bool, scope string) error {
 | 
			
		||||
func testChannels(ctx context.Context, notify bool, scope string) error {
 | 
			
		||||
	if config.RootUserEmail == "" {
 | 
			
		||||
		config.RootUserEmail = model.GetRootUserEmail()
 | 
			
		||||
	}
 | 
			
		||||
@@ -191,7 +239,7 @@ func testChannels(notify bool, scope string) error {
 | 
			
		||||
			isChannelEnabled := channel.Status == model.ChannelStatusEnabled
 | 
			
		||||
			tik := time.Now()
 | 
			
		||||
			testRequest := buildTestRequest("")
 | 
			
		||||
			err, openaiErr := testChannel(channel, testRequest)
 | 
			
		||||
			_, err, openaiErr := testChannel(ctx, channel, testRequest)
 | 
			
		||||
			tok := time.Now()
 | 
			
		||||
			milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
			if isChannelEnabled && milliseconds > disableThreshold {
 | 
			
		||||
@@ -225,11 +273,12 @@ func testChannels(notify bool, scope string) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChannels(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	scope := c.Query("scope")
 | 
			
		||||
	if scope == "" {
 | 
			
		||||
		scope = "all"
 | 
			
		||||
	}
 | 
			
		||||
	err := testChannels(true, scope)
 | 
			
		||||
	err := testChannels(ctx, true, scope)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
@@ -245,10 +294,11 @@ func TestChannels(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AutomaticallyTestChannels(frequency int) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Duration(frequency) * time.Minute)
 | 
			
		||||
		logger.SysLog("testing all channels")
 | 
			
		||||
		_ = testChannels(false, "all")
 | 
			
		||||
		_ = testChannels(ctx, false, "all")
 | 
			
		||||
		logger.SysLog("channel test finished")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) {
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data": gin.H{
 | 
			
		||||
			"version":             common.Version,
 | 
			
		||||
			"start_time":          common.StartTime,
 | 
			
		||||
			"email_verification":  config.EmailVerificationEnabled,
 | 
			
		||||
			"github_oauth":        config.GitHubOAuthEnabled,
 | 
			
		||||
			"github_client_id":    config.GitHubClientId,
 | 
			
		||||
			"lark_client_id":      config.LarkClientId,
 | 
			
		||||
			"system_name":         config.SystemName,
 | 
			
		||||
			"logo":                config.Logo,
 | 
			
		||||
			"footer_html":         config.Footer,
 | 
			
		||||
			"wechat_qrcode":       config.WeChatAccountQRCodeImageURL,
 | 
			
		||||
			"wechat_login":        config.WeChatAuthEnabled,
 | 
			
		||||
			"server_address":      config.ServerAddress,
 | 
			
		||||
			"turnstile_check":     config.TurnstileCheckEnabled,
 | 
			
		||||
			"turnstile_site_key":  config.TurnstileSiteKey,
 | 
			
		||||
			"top_up_link":         config.TopUpLink,
 | 
			
		||||
			"chat_link":           config.ChatLink,
 | 
			
		||||
			"quota_per_unit":      config.QuotaPerUnit,
 | 
			
		||||
			"display_in_currency": config.DisplayInCurrencyEnabled,
 | 
			
		||||
			"version":                     common.Version,
 | 
			
		||||
			"start_time":                  common.StartTime,
 | 
			
		||||
			"email_verification":          config.EmailVerificationEnabled,
 | 
			
		||||
			"github_oauth":                config.GitHubOAuthEnabled,
 | 
			
		||||
			"github_client_id":            config.GitHubClientId,
 | 
			
		||||
			"lark_client_id":              config.LarkClientId,
 | 
			
		||||
			"system_name":                 config.SystemName,
 | 
			
		||||
			"logo":                        config.Logo,
 | 
			
		||||
			"footer_html":                 config.Footer,
 | 
			
		||||
			"wechat_qrcode":               config.WeChatAccountQRCodeImageURL,
 | 
			
		||||
			"wechat_login":                config.WeChatAuthEnabled,
 | 
			
		||||
			"server_address":              config.ServerAddress,
 | 
			
		||||
			"turnstile_check":             config.TurnstileCheckEnabled,
 | 
			
		||||
			"turnstile_site_key":          config.TurnstileSiteKey,
 | 
			
		||||
			"top_up_link":                 config.TopUpLink,
 | 
			
		||||
			"chat_link":                   config.ChatLink,
 | 
			
		||||
			"quota_per_unit":              config.QuotaPerUnit,
 | 
			
		||||
			"display_in_currency":         config.DisplayInCurrencyEnabled,
 | 
			
		||||
			"oidc":                        config.OidcEnabled,
 | 
			
		||||
			"oidc_client_id":              config.OidcClientId,
 | 
			
		||||
			"oidc_well_known":             config.OidcWellKnown,
 | 
			
		||||
			"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
 | 
			
		||||
			"oidc_token_endpoint":         config.OidcTokenEndpoint,
 | 
			
		||||
			"oidc_userinfo_endpoint":      config.OidcUserinfoEndpoint,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
 
 | 
			
		||||
@@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case relaymode.AudioTranscription:
 | 
			
		||||
		err = controller.RelayAudioHelper(c, relayMode)
 | 
			
		||||
	case relaymode.Proxy:
 | 
			
		||||
		err = controller.RelayProxyHelper(c, relayMode)
 | 
			
		||||
	default:
 | 
			
		||||
		err = controller.RelayTextHelper(c)
 | 
			
		||||
	}
 | 
			
		||||
@@ -58,7 +60,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
	channelName := c.GetString(ctxkey.ChannelName)
 | 
			
		||||
	group := c.GetString(ctxkey.Group)
 | 
			
		||||
	originalModel := c.GetString(ctxkey.OriginalModel)
 | 
			
		||||
	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
 | 
			
		||||
	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 | 
			
		||||
	requestId := c.GetString(helper.RequestIdKey)
 | 
			
		||||
	retryTimes := config.RetryTimes
 | 
			
		||||
	if !shouldRetry(c, bizErr.StatusCode) {
 | 
			
		||||
@@ -85,12 +87,14 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
		lastFailedChannelId = channelId
 | 
			
		||||
		channelName := c.GetString(ctxkey.ChannelName)
 | 
			
		||||
		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
 | 
			
		||||
		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 | 
			
		||||
	}
 | 
			
		||||
	if bizErr != nil {
 | 
			
		||||
		if bizErr.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
			bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// BUG: bizErr is in race condition
 | 
			
		||||
		bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
 | 
			
		||||
		c.JSON(bizErr.StatusCode, gin.H{
 | 
			
		||||
			"error": bizErr.Error,
 | 
			
		||||
@@ -117,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
 | 
			
		||||
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) {
 | 
			
		||||
	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
 | 
			
		||||
	// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -109,6 +109,7 @@ func Logout(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Register(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	if !config.RegisterEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"message": "管理员关闭了新用户注册",
 | 
			
		||||
@@ -166,7 +167,7 @@ func Register(c *gin.Context) {
 | 
			
		||||
	if config.EmailVerificationEnabled {
 | 
			
		||||
		cleanUser.Email = user.Email
 | 
			
		||||
	}
 | 
			
		||||
	if err := cleanUser.Insert(inviterId); err != nil {
 | 
			
		||||
	if err := cleanUser.Insert(ctx, inviterId); err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
@@ -362,6 +363,7 @@ func GetSelf(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateUser(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	var updatedUser model.User
 | 
			
		||||
	err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
 | 
			
		||||
	if err != nil || updatedUser.Id == 0 {
 | 
			
		||||
@@ -416,7 +418,7 @@ func UpdateUser(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if originUser.Quota != updatedUser.Quota {
 | 
			
		||||
		model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
 | 
			
		||||
		model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
@@ -535,6 +537,7 @@ func DeleteSelf(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateUser(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	var user model.User
 | 
			
		||||
	err := json.NewDecoder(c.Request.Body).Decode(&user)
 | 
			
		||||
	if err != nil || user.Username == "" || user.Password == "" {
 | 
			
		||||
@@ -568,7 +571,7 @@ func CreateUser(c *gin.Context) {
 | 
			
		||||
		Password:    user.Password,
 | 
			
		||||
		DisplayName: user.DisplayName,
 | 
			
		||||
	}
 | 
			
		||||
	if err := cleanUser.Insert(0); err != nil {
 | 
			
		||||
	if err := cleanUser.Insert(ctx, 0); err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
@@ -747,6 +750,7 @@ type topUpRequest struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TopUp(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	req := topUpRequest{}
 | 
			
		||||
	err := c.ShouldBindJSON(&req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -757,7 +761,7 @@ func TopUp(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	quota, err := model.Redeem(req.Key, id)
 | 
			
		||||
	quota, err := model.Redeem(ctx, req.Key, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
@@ -780,6 +784,7 @@ type adminTopUpRequest struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AdminTopUp(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	req := adminTopUpRequest{}
 | 
			
		||||
	err := c.ShouldBindJSON(&req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -800,7 +805,7 @@ func AdminTopUp(c *gin.Context) {
 | 
			
		||||
	if req.Remark == "" {
 | 
			
		||||
		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
 | 
			
		||||
	}
 | 
			
		||||
	model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
 | 
			
		||||
	model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota)
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										37
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								go.mod
									
									
									
									
									
								
							@@ -4,6 +4,7 @@ module github.com/songquanpeng/one-api
 | 
			
		||||
go 1.20
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	cloud.google.com/go/iam v1.1.10
 | 
			
		||||
	github.com/aws/aws-sdk-go-v2 v1.27.0
 | 
			
		||||
	github.com/aws/aws-sdk-go-v2/credentials v1.17.15
 | 
			
		||||
	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3
 | 
			
		||||
@@ -19,12 +20,14 @@ require (
 | 
			
		||||
	github.com/gorilla/websocket v1.5.1
 | 
			
		||||
	github.com/jinzhu/copier v0.4.0
 | 
			
		||||
	github.com/joho/godotenv v1.5.1
 | 
			
		||||
	github.com/patrickmn/go-cache v2.1.0+incompatible
 | 
			
		||||
	github.com/pkg/errors v0.9.1
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.7
 | 
			
		||||
	github.com/smartystreets/goconvey v1.8.1
 | 
			
		||||
	github.com/stretchr/testify v1.9.0
 | 
			
		||||
	golang.org/x/crypto v0.23.0
 | 
			
		||||
	golang.org/x/crypto v0.31.0
 | 
			
		||||
	golang.org/x/image v0.18.0
 | 
			
		||||
	google.golang.org/api v0.187.0
 | 
			
		||||
	gorm.io/driver/mysql v1.5.6
 | 
			
		||||
	gorm.io/driver/postgres v1.5.7
 | 
			
		||||
	gorm.io/driver/sqlite v1.5.5
 | 
			
		||||
@@ -32,6 +35,9 @@ require (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	cloud.google.com/go/auth v0.6.1 // indirect
 | 
			
		||||
	cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
 | 
			
		||||
	cloud.google.com/go/compute/metadata v0.3.0 // indirect
 | 
			
		||||
	filippo.io/edwards25519 v1.1.0 // indirect
 | 
			
		||||
	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
 | 
			
		||||
	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
 | 
			
		||||
@@ -45,13 +51,21 @@ require (
 | 
			
		||||
	github.com/davecgh/go-spew v1.1.1 // indirect
 | 
			
		||||
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 | 
			
		||||
	github.com/dlclark/regexp2 v1.11.0 // indirect
 | 
			
		||||
	github.com/felixge/httpsnoop v1.0.4 // indirect
 | 
			
		||||
	github.com/fsnotify/fsnotify v1.7.0 // indirect
 | 
			
		||||
	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
 | 
			
		||||
	github.com/gin-contrib/sse v0.1.0 // indirect
 | 
			
		||||
	github.com/go-logr/logr v1.4.1 // indirect
 | 
			
		||||
	github.com/go-logr/stdr v1.2.2 // indirect
 | 
			
		||||
	github.com/go-playground/locales v0.14.1 // indirect
 | 
			
		||||
	github.com/go-playground/universal-translator v0.18.1 // indirect
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.8.1 // indirect
 | 
			
		||||
	github.com/goccy/go-json v0.10.3 // indirect
 | 
			
		||||
	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
 | 
			
		||||
	github.com/golang/protobuf v1.5.4 // indirect
 | 
			
		||||
	github.com/google/s2a-go v0.1.7 // indirect
 | 
			
		||||
	github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
 | 
			
		||||
	github.com/googleapis/gax-go/v2 v2.12.5 // indirect
 | 
			
		||||
	github.com/gopherjs/gopherjs v1.17.2 // indirect
 | 
			
		||||
	github.com/gorilla/context v1.1.2 // indirect
 | 
			
		||||
	github.com/gorilla/securecookie v1.1.2 // indirect
 | 
			
		||||
@@ -76,11 +90,22 @@ require (
 | 
			
		||||
	github.com/smarty/assertions v1.15.0 // indirect
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	github.com/ugorji/go/codec v1.2.12 // indirect
 | 
			
		||||
	go.opencensus.io v0.24.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel v1.24.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/metric v1.24.0 // indirect
 | 
			
		||||
	go.opentelemetry.io/otel/trace v1.24.0 // indirect
 | 
			
		||||
	golang.org/x/arch v0.8.0 // indirect
 | 
			
		||||
	golang.org/x/net v0.25.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.7.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.20.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.16.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.34.1 // indirect
 | 
			
		||||
	golang.org/x/net v0.26.0 // indirect
 | 
			
		||||
	golang.org/x/oauth2 v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.10.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.28.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/time v0.5.0 // indirect
 | 
			
		||||
	google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
 | 
			
		||||
	google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
 | 
			
		||||
	google.golang.org/grpc v1.64.1 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.34.2 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										156
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										156
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,5 +1,15 @@
 | 
			
		||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 | 
			
		||||
cloud.google.com/go/auth v0.6.1 h1:T0Zw1XM5c1GlpN2HYr2s+m3vr1p2wy+8VN+Z1FKxW38=
 | 
			
		||||
cloud.google.com/go/auth v0.6.1/go.mod h1:eFHG7zDzbXHKmjJddFG/rBlcGp6t25SwRUiEQSlO4x4=
 | 
			
		||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
 | 
			
		||||
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
 | 
			
		||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
 | 
			
		||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
 | 
			
		||||
cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI=
 | 
			
		||||
cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps=
 | 
			
		||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
 | 
			
		||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
 | 
			
		||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
 | 
			
		||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
 | 
			
		||||
@@ -18,12 +28,15 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc
 | 
			
		||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
 | 
			
		||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
 | 
			
		||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
 | 
			
		||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
 | 
			
		||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
 | 
			
		||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 | 
			
		||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
 | 
			
		||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
 | 
			
		||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
 | 
			
		||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
 | 
			
		||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
 | 
			
		||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
 | 
			
		||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
 | 
			
		||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
			
		||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 | 
			
		||||
@@ -32,6 +45,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
 | 
			
		||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
 | 
			
		||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
 | 
			
		||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 | 
			
		||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 | 
			
		||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 | 
			
		||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
 | 
			
		||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
 | 
			
		||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
 | 
			
		||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
 | 
			
		||||
@@ -48,6 +67,11 @@ github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0Nglqm
 | 
			
		||||
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
 | 
			
		||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
 | 
			
		||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
 | 
			
		||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
 | 
			
		||||
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
 | 
			
		||||
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
 | 
			
		||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
 | 
			
		||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
 | 
			
		||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
 | 
			
		||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
 | 
			
		||||
@@ -64,11 +88,40 @@ github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
 | 
			
		||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
 | 
			
		||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
 | 
			
		||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
 | 
			
		||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 | 
			
		||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
 | 
			
		||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 | 
			
		||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
 | 
			
		||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 | 
			
		||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 | 
			
		||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 | 
			
		||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 | 
			
		||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
 | 
			
		||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
 | 
			
		||||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
 | 
			
		||||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
 | 
			
		||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
 | 
			
		||||
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
 | 
			
		||||
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
 | 
			
		||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
 | 
			
		||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
 | 
			
		||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
 | 
			
		||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 | 
			
		||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 | 
			
		||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 | 
			
		||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 | 
			
		||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
 | 
			
		||||
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
 | 
			
		||||
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
 | 
			
		||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 | 
			
		||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
 | 
			
		||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
 | 
			
		||||
github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA=
 | 
			
		||||
github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E=
 | 
			
		||||
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
 | 
			
		||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
 | 
			
		||||
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
 | 
			
		||||
@@ -120,6 +173,8 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
 | 
			
		||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
 | 
			
		||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
 | 
			
		||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
 | 
			
		||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
 | 
			
		||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
 | 
			
		||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 | 
			
		||||
@@ -128,6 +183,7 @@ github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQ
 | 
			
		||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
 | 
			
		||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 | 
			
		||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 | 
			
		||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 | 
			
		||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
 | 
			
		||||
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
 | 
			
		||||
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
 | 
			
		||||
@@ -149,26 +205,96 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
 | 
			
		||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 | 
			
		||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
 | 
			
		||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
 | 
			
		||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
 | 
			
		||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
 | 
			
		||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg=
 | 
			
		||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
 | 
			
		||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
 | 
			
		||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
 | 
			
		||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
 | 
			
		||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
 | 
			
		||||
go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI=
 | 
			
		||||
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
 | 
			
		||||
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
 | 
			
		||||
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
 | 
			
		||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
 | 
			
		||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
 | 
			
		||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
 | 
			
		||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 | 
			
		||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
 | 
			
		||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
 | 
			
		||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 | 
			
		||||
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
 | 
			
		||||
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
 | 
			
		||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
 | 
			
		||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
 | 
			
		||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
 | 
			
		||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 | 
			
		||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
 | 
			
		||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
 | 
			
		||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
 | 
			
		||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 | 
			
		||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 | 
			
		||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 | 
			
		||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 | 
			
		||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 | 
			
		||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 | 
			
		||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
 | 
			
		||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
 | 
			
		||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 | 
			
		||||
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
 | 
			
		||||
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
 | 
			
		||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
 | 
			
		||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 | 
			
		||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 | 
			
		||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 | 
			
		||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
 | 
			
		||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
 | 
			
		||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
 | 
			
		||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
 | 
			
		||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 | 
			
		||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
 | 
			
		||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 | 
			
		||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 | 
			
		||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
 | 
			
		||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
 | 
			
		||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
 | 
			
		||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 | 
			
		||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 | 
			
		||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 | 
			
		||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
 | 
			
		||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
 | 
			
		||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo=
 | 
			
		||||
google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk=
 | 
			
		||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
 | 
			
		||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
 | 
			
		||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
 | 
			
		||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
 | 
			
		||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
 | 
			
		||||
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc=
 | 
			
		||||
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c=
 | 
			
		||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk=
 | 
			
		||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
 | 
			
		||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 | 
			
		||||
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
 | 
			
		||||
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
 | 
			
		||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
 | 
			
		||||
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
 | 
			
		||||
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
 | 
			
		||||
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
 | 
			
		||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
 | 
			
		||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
 | 
			
		||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
 | 
			
		||||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
 | 
			
		||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 | 
			
		||||
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 | 
			
		||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 | 
			
		||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 | 
			
		||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
 | 
			
		||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
 | 
			
		||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
 | 
			
		||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
			
		||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 | 
			
		||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
 | 
			
		||||
@@ -185,5 +311,7 @@ gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATa
 | 
			
		||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 | 
			
		||||
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
 | 
			
		||||
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 | 
			
		||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 | 
			
		||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 | 
			
		||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
 | 
			
		||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
 | 
			
		||||
 
 | 
			
		||||
@@ -140,6 +140,12 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// set channel id for proxy relay
 | 
			
		||||
		if channelId := c.Param("channelid"); channelId != "" {
 | 
			
		||||
			c.Set(ctxkey.SpecificChannelId, channelId)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,21 +2,24 @@ package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ModelRequest struct {
 | 
			
		||||
	Model string `json:"model"`
 | 
			
		||||
	Model string `json:"model" form:"model"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Distribute() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		ctx := c.Request.Context()
 | 
			
		||||
		userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
		userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
		c.Set(ctxkey.Group, userGroup)
 | 
			
		||||
@@ -52,6 +55,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id)
 | 
			
		||||
		SetupContextForSelectedChannel(c, channel, requestModel)
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
@@ -61,6 +65,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 | 
			
		||||
	c.Set(ctxkey.Channel, channel.Type)
 | 
			
		||||
	c.Set(ctxkey.ChannelId, channel.Id)
 | 
			
		||||
	c.Set(ctxkey.ChannelName, channel.Name)
 | 
			
		||||
	if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
 | 
			
		||||
		c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
 | 
			
		||||
	}
 | 
			
		||||
	c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
 | 
			
		||||
	c.Set(ctxkey.OriginalModel, modelName) // for retry
 | 
			
		||||
	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								middleware/gzip.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"compress/gzip"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GzipDecodeMiddleware() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		if c.GetHeader("Content-Encoding") == "gzip" {
 | 
			
		||||
			gzipReader, err := gzip.NewReader(c.Request.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.AbortWithStatus(http.StatusBadRequest)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			defer gzipReader.Close()
 | 
			
		||||
 | 
			
		||||
			// Replace the request body with the decompressed data
 | 
			
		||||
			c.Request.Body = io.NopCloser(gzipReader)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Continue processing the request
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -3,11 +3,12 @@ package middleware
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var timeFormat = "2006-01-02T15:04:05.000Z"
 | 
			
		||||
@@ -70,6 +71,11 @@ func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
 | 
			
		||||
	if maxRequestNum == 0 {
 | 
			
		||||
		return func(c *gin.Context) {
 | 
			
		||||
			c.Next()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if common.RedisEnabled {
 | 
			
		||||
		return func(c *gin.Context) {
 | 
			
		||||
			redisRateLimiter(c, maxRequestNum, duration, mark)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,8 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		id := helper.GenRequestID()
 | 
			
		||||
		c.Set(helper.RequestIdKey, id)
 | 
			
		||||
		ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id)
 | 
			
		||||
		ctx := helper.SetRequestID(c.Request.Context(), id)
 | 
			
		||||
		c.Request = c.Request.WithContext(ctx)
 | 
			
		||||
		c.Header(helper.RequestIdKey, id)
 | 
			
		||||
		c.Next()
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
@@ -36,16 +37,19 @@ type Channel struct {
 | 
			
		||||
	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 | 
			
		||||
	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 | 
			
		||||
	Config             string  `json:"config"`
 | 
			
		||||
	SystemPrompt       *string `json:"system_prompt" gorm:"type:text"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChannelConfig struct {
 | 
			
		||||
	Region     string `json:"region,omitempty"`
 | 
			
		||||
	SK         string `json:"sk,omitempty"`
 | 
			
		||||
	AK         string `json:"ak,omitempty"`
 | 
			
		||||
	UserID     string `json:"user_id,omitempty"`
 | 
			
		||||
	APIVersion string `json:"api_version,omitempty"`
 | 
			
		||||
	LibraryID  string `json:"library_id,omitempty"`
 | 
			
		||||
	Plugin     string `json:"plugin,omitempty"`
 | 
			
		||||
	Region            string `json:"region,omitempty"`
 | 
			
		||||
	SK                string `json:"sk,omitempty"`
 | 
			
		||||
	AK                string `json:"ak,omitempty"`
 | 
			
		||||
	UserID            string `json:"user_id,omitempty"`
 | 
			
		||||
	APIVersion        string `json:"api_version,omitempty"`
 | 
			
		||||
	LibraryID         string `json:"library_id,omitempty"`
 | 
			
		||||
	Plugin            string `json:"plugin,omitempty"`
 | 
			
		||||
	VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"`
 | 
			
		||||
	VertexAIADC       string `json:"vertex_ai_adc,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										100
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								model/log.go
									
									
									
									
									
								
							@@ -3,26 +3,32 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Log struct {
 | 
			
		||||
	Id               int    `json:"id"`
 | 
			
		||||
	UserId           int    `json:"user_id" gorm:"index"`
 | 
			
		||||
	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
 | 
			
		||||
	Type             int    `json:"type" gorm:"index:idx_created_at_type"`
 | 
			
		||||
	Content          string `json:"content"`
 | 
			
		||||
	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
 | 
			
		||||
	TokenName        string `json:"token_name" gorm:"index;default:''"`
 | 
			
		||||
	ModelName        string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
 | 
			
		||||
	Quota            int    `json:"quota" gorm:"default:0"`
 | 
			
		||||
	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"`
 | 
			
		||||
	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"`
 | 
			
		||||
	ChannelId        int    `json:"channel" gorm:"index"`
 | 
			
		||||
	Id                int    `json:"id"`
 | 
			
		||||
	UserId            int    `json:"user_id" gorm:"index"`
 | 
			
		||||
	CreatedAt         int64  `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
 | 
			
		||||
	Type              int    `json:"type" gorm:"index:idx_created_at_type"`
 | 
			
		||||
	Content           string `json:"content"`
 | 
			
		||||
	Username          string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
 | 
			
		||||
	TokenName         string `json:"token_name" gorm:"index;default:''"`
 | 
			
		||||
	ModelName         string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
 | 
			
		||||
	Quota             int    `json:"quota" gorm:"default:0"`
 | 
			
		||||
	PromptTokens      int    `json:"prompt_tokens" gorm:"default:0"`
 | 
			
		||||
	CompletionTokens  int    `json:"completion_tokens" gorm:"default:0"`
 | 
			
		||||
	ChannelId         int    `json:"channel" gorm:"index"`
 | 
			
		||||
	RequestId         string `json:"request_id" gorm:"default:''"`
 | 
			
		||||
	ElapsedTime       int64  `json:"elapsed_time" gorm:"default:0"` // unit is ms
 | 
			
		||||
	IsStream          bool   `json:"is_stream" gorm:"default:false"`
 | 
			
		||||
	SystemPromptReset bool   `json:"system_prompt_reset" gorm:"default:false"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -31,9 +37,21 @@ const (
 | 
			
		||||
	LogTypeConsume
 | 
			
		||||
	LogTypeManage
 | 
			
		||||
	LogTypeSystem
 | 
			
		||||
	LogTypeTest
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
func recordLogHelper(ctx context.Context, log *Log) {
 | 
			
		||||
	requestId := helper.GetRequestID(ctx)
 | 
			
		||||
	log.RequestId = requestId
 | 
			
		||||
	err := LOG_DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(ctx, "failed to record log: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logger.Infof(ctx, "record log: %+v", log)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordLog(ctx context.Context, userId int, logType int, content string) {
 | 
			
		||||
	if logType == LogTypeConsume && !config.LogConsumeEnabled {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -44,13 +62,10 @@ func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
		Type:      logType,
 | 
			
		||||
		Content:   content,
 | 
			
		||||
	}
 | 
			
		||||
	err := LOG_DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("failed to record log: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	recordLogHelper(ctx, log)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordTopupLog(userId int, content string, quota int) {
 | 
			
		||||
func RecordTopupLog(ctx context.Context, userId int, content string, quota int) {
 | 
			
		||||
	log := &Log{
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		Username:  GetUsernameById(userId),
 | 
			
		||||
@@ -59,34 +74,23 @@ func RecordTopupLog(userId int, content string, quota int) {
 | 
			
		||||
		Content:   content,
 | 
			
		||||
		Quota:     quota,
 | 
			
		||||
	}
 | 
			
		||||
	err := LOG_DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("failed to record log: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	recordLogHelper(ctx, log)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
 | 
			
		||||
	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, log *Log) {
 | 
			
		||||
	if !config.LogConsumeEnabled {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	log := &Log{
 | 
			
		||||
		UserId:           userId,
 | 
			
		||||
		Username:         GetUsernameById(userId),
 | 
			
		||||
		CreatedAt:        helper.GetTimestamp(),
 | 
			
		||||
		Type:             LogTypeConsume,
 | 
			
		||||
		Content:          content,
 | 
			
		||||
		PromptTokens:     promptTokens,
 | 
			
		||||
		CompletionTokens: completionTokens,
 | 
			
		||||
		TokenName:        tokenName,
 | 
			
		||||
		ModelName:        modelName,
 | 
			
		||||
		Quota:            int(quota),
 | 
			
		||||
		ChannelId:        channelId,
 | 
			
		||||
	}
 | 
			
		||||
	err := LOG_DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(ctx, "failed to record log: "+err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	log.Username = GetUsernameById(log.UserId)
 | 
			
		||||
	log.CreatedAt = helper.GetTimestamp()
 | 
			
		||||
	log.Type = LogTypeConsume
 | 
			
		||||
	recordLogHelper(ctx, log)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordTestLog(ctx context.Context, log *Log) {
 | 
			
		||||
	log.CreatedAt = helper.GetTimestamp()
 | 
			
		||||
	log.Type = LogTypeTest
 | 
			
		||||
	recordLogHelper(ctx, log)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
 | 
			
		||||
@@ -152,7 +156,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
 | 
			
		||||
	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
 | 
			
		||||
	ifnull := "ifnull"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		ifnull = "COALESCE"
 | 
			
		||||
	}
 | 
			
		||||
	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		tx = tx.Where("username = ?", username)
 | 
			
		||||
	}
 | 
			
		||||
@@ -176,7 +184,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
 | 
			
		||||
	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
 | 
			
		||||
	ifnull := "ifnull"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		ifnull = "COALESCE"
 | 
			
		||||
	}
 | 
			
		||||
	tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		tx = tx.Where("username = ?", username)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -28,6 +28,7 @@ func InitOptionMap() {
 | 
			
		||||
	config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
 | 
			
		||||
	config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
 | 
			
		||||
	config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
 | 
			
		||||
	config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
 | 
			
		||||
	config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
 | 
			
		||||
	config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
 | 
			
		||||
	config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
 | 
			
		||||
@@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
			config.EmailVerificationEnabled = boolValue
 | 
			
		||||
		case "GitHubOAuthEnabled":
 | 
			
		||||
			config.GitHubOAuthEnabled = boolValue
 | 
			
		||||
		case "OidcEnabled":
 | 
			
		||||
			config.OidcEnabled = boolValue
 | 
			
		||||
		case "WeChatAuthEnabled":
 | 
			
		||||
			config.WeChatAuthEnabled = boolValue
 | 
			
		||||
		case "TurnstileCheckEnabled":
 | 
			
		||||
@@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
		config.LarkClientId = value
 | 
			
		||||
	case "LarkClientSecret":
 | 
			
		||||
		config.LarkClientSecret = value
 | 
			
		||||
	case "OidcClientId":
 | 
			
		||||
		config.OidcClientId = value
 | 
			
		||||
	case "OidcClientSecret":
 | 
			
		||||
		config.OidcClientSecret = value
 | 
			
		||||
	case "OidcWellKnown":
 | 
			
		||||
		config.OidcWellKnown = value
 | 
			
		||||
	case "OidcAuthorizationEndpoint":
 | 
			
		||||
		config.OidcAuthorizationEndpoint = value
 | 
			
		||||
	case "OidcTokenEndpoint":
 | 
			
		||||
		config.OidcTokenEndpoint = value
 | 
			
		||||
	case "OidcUserinfoEndpoint":
 | 
			
		||||
		config.OidcUserinfoEndpoint = value
 | 
			
		||||
	case "Footer":
 | 
			
		||||
		config.Footer = value
 | 
			
		||||
	case "SystemName":
 | 
			
		||||
 
 | 
			
		||||
@@ -1,11 +1,14 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -48,7 +51,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
 | 
			
		||||
	return &redemption, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Redeem(key string, userId int) (quota int64, err error) {
 | 
			
		||||
func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) {
 | 
			
		||||
	if key == "" {
 | 
			
		||||
		return 0, errors.New("未提供兑换码")
 | 
			
		||||
	}
 | 
			
		||||
@@ -82,7 +85,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.New("兑换失败," + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
 | 
			
		||||
	RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
 | 
			
		||||
	return redemption.Quota, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -30,7 +30,7 @@ type Token struct {
 | 
			
		||||
	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"`
 | 
			
		||||
	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"`
 | 
			
		||||
	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota
 | 
			
		||||
	Models         *string `json:"models" gorm:"default:''"`           // allowed models
 | 
			
		||||
	Models         *string `json:"models" gorm:"type:text"`            // allowed models
 | 
			
		||||
	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
 | 
			
		||||
	return &token, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (token *Token) Insert() error {
 | 
			
		||||
func (t *Token) Insert() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Create(token).Error
 | 
			
		||||
	err = DB.Create(t).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Update Make sure your token's fields is completed, because this will update non-zero values
 | 
			
		||||
func (token *Token) Update() error {
 | 
			
		||||
func (t *Token) Update() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
 | 
			
		||||
	err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (token *Token) SelectUpdate() error {
 | 
			
		||||
func (t *Token) SelectUpdate() error {
 | 
			
		||||
	// This can update zero values
 | 
			
		||||
	return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
 | 
			
		||||
	return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (token *Token) Delete() error {
 | 
			
		||||
func (t *Token) Delete() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Delete(token).Error
 | 
			
		||||
	err = DB.Delete(t).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Token) GetModels() string {
 | 
			
		||||
	if t == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	if t.Models == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return *t.Models
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteTokenById(id int, userId int) (err error) {
 | 
			
		||||
	// Why we need userId here? In case user want to delete other's token.
 | 
			
		||||
	if id == 0 || userId == 0 {
 | 
			
		||||
@@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
 | 
			
		||||
 | 
			
		||||
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
 | 
			
		||||
	token, err := GetTokenById(tokenId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if quota > 0 {
 | 
			
		||||
		err = DecreaseUserQuota(token.UserId, quota)
 | 
			
		||||
	} else {
 | 
			
		||||
		err = IncreaseUserQuota(token.UserId, -quota)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if !token.UnlimitedQuota {
 | 
			
		||||
		if quota > 0 {
 | 
			
		||||
			err = DecreaseTokenQuota(tokenId, quota)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,16 +1,19 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/blacklist"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -39,6 +42,7 @@ type User struct {
 | 
			
		||||
	GitHubId         string `json:"github_id" gorm:"column:github_id;index"`
 | 
			
		||||
	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"`
 | 
			
		||||
	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"`
 | 
			
		||||
	OidcId           string `json:"oidc_id" gorm:"column:oidc_id;index"`
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
 | 
			
		||||
	Quota            int64  `json:"quota" gorm:"bigint;default:0"`
 | 
			
		||||
@@ -91,7 +95,7 @@ func GetUserById(id int, selectAll bool) (*User, error) {
 | 
			
		||||
	if selectAll {
 | 
			
		||||
		err = DB.First(&user, "id = ?", id).Error
 | 
			
		||||
	} else {
 | 
			
		||||
		err = DB.Omit("password").First(&user, "id = ?", id).Error
 | 
			
		||||
		err = DB.Omit("password", "access_token").First(&user, "id = ?", id).Error
 | 
			
		||||
	}
 | 
			
		||||
	return &user, err
 | 
			
		||||
}
 | 
			
		||||
@@ -113,7 +117,7 @@ func DeleteUserById(id int) (err error) {
 | 
			
		||||
	return user.Delete()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (user *User) Insert(inviterId int) error {
 | 
			
		||||
func (user *User) Insert(ctx context.Context, inviterId int) error {
 | 
			
		||||
	var err error
 | 
			
		||||
	if user.Password != "" {
 | 
			
		||||
		user.Password, err = common.Password2Hash(user.Password)
 | 
			
		||||
@@ -129,16 +133,16 @@ func (user *User) Insert(inviterId int) error {
 | 
			
		||||
		return result.Error
 | 
			
		||||
	}
 | 
			
		||||
	if config.QuotaForNewUser > 0 {
 | 
			
		||||
		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
 | 
			
		||||
		RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
 | 
			
		||||
	}
 | 
			
		||||
	if inviterId != 0 {
 | 
			
		||||
		if config.QuotaForInvitee > 0 {
 | 
			
		||||
			_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
 | 
			
		||||
			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
 | 
			
		||||
			RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
 | 
			
		||||
		}
 | 
			
		||||
		if config.QuotaForInviter > 0 {
 | 
			
		||||
			_ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
 | 
			
		||||
			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
 | 
			
		||||
			RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// create default token
 | 
			
		||||
@@ -245,6 +249,14 @@ func (user *User) FillUserByLarkId() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (user *User) FillUserByOidcId() error {
 | 
			
		||||
	if user.OidcId == "" {
 | 
			
		||||
		return errors.New("oidc id 为空!")
 | 
			
		||||
	}
 | 
			
		||||
	DB.Where(User{OidcId: user.OidcId}).First(user)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (user *User) FillUserByWeChatId() error {
 | 
			
		||||
	if user.WeChatId == "" {
 | 
			
		||||
		return errors.New("WeChat id 为空!")
 | 
			
		||||
@@ -277,6 +289,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
 | 
			
		||||
	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
 | 
			
		||||
	return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsUsernameAlreadyTaken(username string) bool {
 | 
			
		||||
	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,11 @@
 | 
			
		||||
package monitor
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
 | 
			
		||||
@@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	switch err.Type {
 | 
			
		||||
	case "insufficient_quota":
 | 
			
		||||
		return true
 | 
			
		||||
	// https://docs.anthropic.com/claude/reference/errors
 | 
			
		||||
	case "authentication_error":
 | 
			
		||||
		return true
 | 
			
		||||
	case "permission_error":
 | 
			
		||||
		return true
 | 
			
		||||
	case "forbidden":
 | 
			
		||||
	case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
 | 
			
		||||
		return true
 | 
			
		||||
	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	//if strings.Contains(err.Message, "quota") {
 | 
			
		||||
	//	return true
 | 
			
		||||
	//}
 | 
			
		||||
	if strings.Contains(err.Message, "credit") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(err.Message, "balance") {
 | 
			
		||||
 | 
			
		||||
	lowerMessage := strings.ToLower(err.Message)
 | 
			
		||||
	if strings.Contains(lowerMessage, "your access was terminated") ||
 | 
			
		||||
		strings.Contains(lowerMessage, "violation of our policies") ||
 | 
			
		||||
		strings.Contains(lowerMessage, "your credit balance is too low") ||
 | 
			
		||||
		strings.Contains(lowerMessage, "organization has been disabled") ||
 | 
			
		||||
		strings.Contains(lowerMessage, "credit") ||
 | 
			
		||||
		strings.Contains(lowerMessage, "balance") ||
 | 
			
		||||
		strings.Contains(lowerMessage, "permission denied") ||
 | 
			
		||||
		strings.Contains(lowerMessage, "organization has been restricted") || // groq
 | 
			
		||||
		strings.Contains(lowerMessage, "已欠费") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
 
 | 
			
		||||
@@ -15,7 +15,10 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/ollama"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/palm"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/proxy"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/replicate"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/tencent"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/zhipu"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/apitype"
 | 
			
		||||
@@ -55,6 +58,12 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
 | 
			
		||||
		return &cloudflare.Adaptor{}
 | 
			
		||||
	case apitype.DeepL:
 | 
			
		||||
		return &deepl.Adaptor{}
 | 
			
		||||
	case apitype.VertexAI:
 | 
			
		||||
		return &vertexai.Adaptor{}
 | 
			
		||||
	case apitype.Proxy:
 | 
			
		||||
		return &proxy.Adaptor{}
 | 
			
		||||
	case apitype.Replicate:
 | 
			
		||||
		return &replicate.Adaptor{}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,23 @@
 | 
			
		||||
package ali
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
 | 
			
		||||
	"text-embedding-v1",
 | 
			
		||||
	"qwen-turbo", "qwen-turbo-latest",
 | 
			
		||||
	"qwen-plus", "qwen-plus-latest",
 | 
			
		||||
	"qwen-max", "qwen-max-latest",
 | 
			
		||||
	"qwen-max-longcontext",
 | 
			
		||||
	"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
 | 
			
		||||
	"qwen-vl-ocr", "qwen-vl-ocr-latest",
 | 
			
		||||
	"qwen-audio-turbo",
 | 
			
		||||
	"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
 | 
			
		||||
	"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
 | 
			
		||||
	"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
 | 
			
		||||
	"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
 | 
			
		||||
	"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
 | 
			
		||||
	"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
 | 
			
		||||
	"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
 | 
			
		||||
	"qwen2-audio-instruct", "qwen-audio-chat",
 | 
			
		||||
	"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
 | 
			
		||||
	"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
 | 
			
		||||
	"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
 | 
			
		||||
	"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package ali
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
		enableSearch = true
 | 
			
		||||
		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 | 
			
		||||
	}
 | 
			
		||||
	if request.TopP >= 1 {
 | 
			
		||||
		request.TopP = 0.9999
 | 
			
		||||
	}
 | 
			
		||||
	request.TopP = helper.Float64PtrMax(request.TopP, 0.9999)
 | 
			
		||||
	return &ChatRequest{
 | 
			
		||||
		Model: aliModel,
 | 
			
		||||
		Input: Input{
 | 
			
		||||
@@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
 | 
			
		||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 | 
			
		||||
	return &EmbeddingRequest{
 | 
			
		||||
		Model: "text-embedding-v1",
 | 
			
		||||
		Model: request.Model,
 | 
			
		||||
		Input: struct {
 | 
			
		||||
			Texts []string `json:"texts"`
 | 
			
		||||
		}{
 | 
			
		||||
@@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	requestModel := c.GetString(ctxkey.RequestModel)
 | 
			
		||||
	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
	fullTextResponse.Model = requestModel
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
 
 | 
			
		||||
@@ -16,13 +16,13 @@ type Input struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Parameters struct {
 | 
			
		||||
	TopP              float64      `json:"top_p,omitempty"`
 | 
			
		||||
	TopP              *float64     `json:"top_p,omitempty"`
 | 
			
		||||
	TopK              int          `json:"top_k,omitempty"`
 | 
			
		||||
	Seed              uint64       `json:"seed,omitempty"`
 | 
			
		||||
	EnableSearch      bool         `json:"enable_search,omitempty"`
 | 
			
		||||
	IncrementalOutput bool         `json:"incremental_output,omitempty"`
 | 
			
		||||
	MaxTokens         int          `json:"max_tokens,omitempty"`
 | 
			
		||||
	Temperature       float64      `json:"temperature,omitempty"`
 | 
			
		||||
	Temperature       *float64     `json:"temperature,omitempty"`
 | 
			
		||||
	ResultFormat      string       `json:"result_format,omitempty"`
 | 
			
		||||
	Tools             []model.Tool `json:"tools,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,12 +3,14 @@ package anthropic
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
@@ -31,6 +33,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("anthropic-version", anthropicVersion)
 | 
			
		||||
	req.Header.Set("anthropic-beta", "messages-2023-12-15")
 | 
			
		||||
 | 
			
		||||
	// https://x.com/alexalbert__/status/1812921642143900036
 | 
			
		||||
	// claude-3-5-sonnet can support 8k context
 | 
			
		||||
	if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
 | 
			
		||||
		req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,10 @@ package anthropic
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"claude-instant-1.2", "claude-2.0", "claude-2.1",
 | 
			
		||||
	"claude-3-haiku-20240307",
 | 
			
		||||
	"claude-3-5-haiku-20241022",
 | 
			
		||||
	"claude-3-sonnet-20240229",
 | 
			
		||||
	"claude-3-opus-20240229",
 | 
			
		||||
	"claude-3-5-sonnet-20240620",
 | 
			
		||||
	"claude-3-5-sonnet-20241022",
 | 
			
		||||
	"claude-3-5-sonnet-latest",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -48,8 +48,8 @@ type Request struct {
 | 
			
		||||
	MaxTokens     int       `json:"max_tokens,omitempty"`
 | 
			
		||||
	StopSequences []string  `json:"stop_sequences,omitempty"`
 | 
			
		||||
	Stream        bool      `json:"stream,omitempty"`
 | 
			
		||||
	Temperature   float64   `json:"temperature,omitempty"`
 | 
			
		||||
	TopP          float64   `json:"top_p,omitempty"`
 | 
			
		||||
	Temperature   *float64  `json:"temperature,omitempty"`
 | 
			
		||||
	TopP          *float64  `json:"top_p,omitempty"`
 | 
			
		||||
	TopK          int       `json:"top_k,omitempty"`
 | 
			
		||||
	Tools         []Tool    `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice    any       `json:"tool_choice,omitempty"`
 | 
			
		||||
 
 | 
			
		||||
@@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{
 | 
			
		||||
	"claude-instant-1.2":         "anthropic.claude-instant-v1",
 | 
			
		||||
	"claude-2.0":                 "anthropic.claude-v2",
 | 
			
		||||
	"claude-2.1":                 "anthropic.claude-v2:1",
 | 
			
		||||
	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0",
 | 
			
		||||
	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
 | 
			
		||||
	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0",
 | 
			
		||||
	"claude-3-haiku-20240307":    "anthropic.claude-3-haiku-20240307-v1:0",
 | 
			
		||||
	"claude-3-sonnet-20240229":   "anthropic.claude-3-sonnet-20240229-v1:0",
 | 
			
		||||
	"claude-3-opus-20240229":     "anthropic.claude-3-opus-20240229-v1:0",
 | 
			
		||||
	"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
 | 
			
		||||
	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
 | 
			
		||||
	"claude-3-5-sonnet-latest":   "anthropic.claude-3-5-sonnet-20241022-v2:0",
 | 
			
		||||
	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func awsModelID(requestModel string) (string, error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -11,8 +11,8 @@ type Request struct {
 | 
			
		||||
	Messages         []anthropic.Message `json:"messages"`
 | 
			
		||||
	System           string              `json:"system,omitempty"`
 | 
			
		||||
	MaxTokens        int                 `json:"max_tokens,omitempty"`
 | 
			
		||||
	Temperature      float64             `json:"temperature,omitempty"`
 | 
			
		||||
	TopP             float64             `json:"top_p,omitempty"`
 | 
			
		||||
	Temperature      *float64            `json:"temperature,omitempty"`
 | 
			
		||||
	TopP             *float64            `json:"top_p,omitempty"`
 | 
			
		||||
	TopK             int                 `json:"top_k,omitempty"`
 | 
			
		||||
	StopSequences    []string            `json:"stop_sequences,omitempty"`
 | 
			
		||||
	Tools            []anthropic.Tool    `json:"tools,omitempty"`
 | 
			
		||||
 
 | 
			
		||||
@@ -4,10 +4,10 @@ package aws
 | 
			
		||||
//
 | 
			
		||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
 | 
			
		||||
type Request struct {
 | 
			
		||||
	Prompt      string  `json:"prompt"`
 | 
			
		||||
	MaxGenLen   int     `json:"max_gen_len,omitempty"`
 | 
			
		||||
	Temperature float64 `json:"temperature,omitempty"`
 | 
			
		||||
	TopP        float64 `json:"top_p,omitempty"`
 | 
			
		||||
	Prompt      string   `json:"prompt"`
 | 
			
		||||
	MaxGenLen   int      `json:"max_gen_len,omitempty"`
 | 
			
		||||
	Temperature *float64 `json:"temperature,omitempty"`
 | 
			
		||||
	TopP        *float64 `json:"top_p,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Response is the response from AWS Llama3
 | 
			
		||||
 
 | 
			
		||||
@@ -35,9 +35,9 @@ type Message struct {
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
	Messages        []Message `json:"messages"`
 | 
			
		||||
	Temperature     float64   `json:"temperature,omitempty"`
 | 
			
		||||
	TopP            float64   `json:"top_p,omitempty"`
 | 
			
		||||
	PenaltyScore    float64   `json:"penalty_score,omitempty"`
 | 
			
		||||
	Temperature     *float64  `json:"temperature,omitempty"`
 | 
			
		||||
	TopP            *float64  `json:"top_p,omitempty"`
 | 
			
		||||
	PenaltyScore    *float64  `json:"penalty_score,omitempty"`
 | 
			
		||||
	Stream          bool      `json:"stream,omitempty"`
 | 
			
		||||
	System          string    `json:"system,omitempty"`
 | 
			
		||||
	DisableSearch   bool      `json:"disable_search,omitempty"`
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
@@ -28,14 +29,32 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
	a.meta = meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WorkerAI cannot be used across accounts with AIGateWay
 | 
			
		||||
// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints
 | 
			
		||||
// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai
 | 
			
		||||
func (a *Adaptor) isAIGateWay(baseURL string) bool {
 | 
			
		||||
	return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	isAIGateWay := a.isAIGateWay(meta.BaseURL)
 | 
			
		||||
	var urlPrefix string
 | 
			
		||||
	if isAIGateWay {
 | 
			
		||||
		urlPrefix = meta.BaseURL
 | 
			
		||||
	} else {
 | 
			
		||||
		urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ChatCompletions:
 | 
			
		||||
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
 | 
			
		||||
		return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil
 | 
			
		||||
	case relaymode.Embeddings:
 | 
			
		||||
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
 | 
			
		||||
		return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
 | 
			
		||||
		if isAIGateWay {
 | 
			
		||||
			return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package cloudflare
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"@cf/meta/llama-3.1-8b-instruct",
 | 
			
		||||
	"@cf/meta/llama-2-7b-chat-fp16",
 | 
			
		||||
	"@cf/meta/llama-2-7b-chat-int8",
 | 
			
		||||
	"@cf/mistral/mistral-7b-instruct-v0.1",
 | 
			
		||||
 
 | 
			
		||||
@@ -9,5 +9,5 @@ type Request struct {
 | 
			
		||||
	Prompt      string          `json:"prompt,omitempty"`
 | 
			
		||||
	Raw         bool            `json:"raw,omitempty"`
 | 
			
		||||
	Stream      bool            `json:"stream,omitempty"`
 | 
			
		||||
	Temperature float64         `json:"temperature,omitempty"`
 | 
			
		||||
	Temperature *float64        `json:"temperature,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
		K:                textRequest.TopK,
 | 
			
		||||
		Stream:           textRequest.Stream,
 | 
			
		||||
		FrequencyPenalty: textRequest.FrequencyPenalty,
 | 
			
		||||
		PresencePenalty:  textRequest.FrequencyPenalty,
 | 
			
		||||
		PresencePenalty:  textRequest.PresencePenalty,
 | 
			
		||||
		Seed:             int(textRequest.Seed),
 | 
			
		||||
	}
 | 
			
		||||
	if cohereRequest.Model == "" {
 | 
			
		||||
 
 | 
			
		||||
@@ -10,15 +10,15 @@ type Request struct {
 | 
			
		||||
	PromptTruncation string        `json:"prompt_truncation,omitempty"` // 默认值为"AUTO"
 | 
			
		||||
	Connectors       []Connector   `json:"connectors,omitempty"`
 | 
			
		||||
	Documents        []Document    `json:"documents,omitempty"`
 | 
			
		||||
	Temperature      float64       `json:"temperature,omitempty"` // 默认值为0.3
 | 
			
		||||
	Temperature      *float64      `json:"temperature,omitempty"` // 默认值为0.3
 | 
			
		||||
	MaxTokens        int           `json:"max_tokens,omitempty"`
 | 
			
		||||
	MaxInputTokens   int           `json:"max_input_tokens,omitempty"`
 | 
			
		||||
	K                int           `json:"k,omitempty"` // 默认值为0
 | 
			
		||||
	P                float64       `json:"p,omitempty"` // 默认值为0.75
 | 
			
		||||
	P                *float64      `json:"p,omitempty"` // 默认值为0.75
 | 
			
		||||
	Seed             int           `json:"seed,omitempty"`
 | 
			
		||||
	StopSequences    []string      `json:"stop_sequences,omitempty"`
 | 
			
		||||
	FrequencyPenalty float64       `json:"frequency_penalty,omitempty"` // 默认值为0.0
 | 
			
		||||
	PresencePenalty  float64       `json:"presence_penalty,omitempty"`  // 默认值为0.0
 | 
			
		||||
	FrequencyPenalty *float64      `json:"frequency_penalty,omitempty"` // 默认值为0.0
 | 
			
		||||
	PresencePenalty  *float64      `json:"presence_penalty,omitempty"`  // 默认值为0.0
 | 
			
		||||
	Tools            []Tool        `json:"tools,omitempty"`
 | 
			
		||||
	ToolResults      []ToolResult  `json:"tool_results,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,8 +7,12 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	if meta.Mode == relaymode.ChatCompletions {
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ChatCompletions:
 | 
			
		||||
		return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
 | 
			
		||||
	case relaymode.Embeddings:
 | 
			
		||||
		return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
	return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,6 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
@@ -24,7 +23,15 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
 | 
			
		||||
	var defaultVersion string
 | 
			
		||||
	switch meta.ActualModelName {
 | 
			
		||||
	case "gemini-2.0-flash-exp",
 | 
			
		||||
		"gemini-2.0-flash-thinking-exp",
 | 
			
		||||
		"gemini-2.0-flash-thinking-exp-01-21":
 | 
			
		||||
		defaultVersion = "v1beta"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
 | 
			
		||||
	action := ""
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.Embeddings:
 | 
			
		||||
@@ -36,6 +43,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		action = "streamGenerateContent?alt=sse"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,9 @@ package gemini
 | 
			
		||||
// https://ai.google.dev/models/gemini
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
 | 
			
		||||
	"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
 | 
			
		||||
	"gemini-pro", "gemini-1.0-pro",
 | 
			
		||||
	"gemini-1.5-flash", "gemini-1.5-pro",
 | 
			
		||||
	"text-embedding-004", "aqa",
 | 
			
		||||
	"gemini-2.0-flash-exp",
 | 
			
		||||
	"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,11 +4,12 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
@@ -28,6 +29,11 @@ const (
 | 
			
		||||
	VisionMaxImageNum = 16
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var mimeTypeMap = map[string]string{
 | 
			
		||||
	"json_object": "application/json",
 | 
			
		||||
	"text":        "text/plain",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
 | 
			
		||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
	geminiRequest := ChatRequest{
 | 
			
		||||
@@ -49,6 +55,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
 | 
			
		||||
				Threshold: config.GeminiSafetySetting,
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY",
 | 
			
		||||
				Threshold: config.GeminiSafetySetting,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		GenerationConfig: ChatGenerationConfig{
 | 
			
		||||
			Temperature:     textRequest.Temperature,
 | 
			
		||||
@@ -56,6 +66,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
			MaxOutputTokens: textRequest.MaxTokens,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	if textRequest.ResponseFormat != nil {
 | 
			
		||||
		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
 | 
			
		||||
			geminiRequest.GenerationConfig.ResponseMimeType = mimeType
 | 
			
		||||
		}
 | 
			
		||||
		if textRequest.ResponseFormat.JsonSchema != nil {
 | 
			
		||||
			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
 | 
			
		||||
			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if textRequest.Tools != nil {
 | 
			
		||||
		functions := make([]model.Function, 0, len(textRequest.Tools))
 | 
			
		||||
		for _, tool := range textRequest.Tools {
 | 
			
		||||
@@ -232,7 +251,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
			
		||||
			if candidate.Content.Parts[0].FunctionCall != nil {
 | 
			
		||||
				choice.Message.ToolCalls = getToolCalls(&candidate)
 | 
			
		||||
			} else {
 | 
			
		||||
				choice.Message.Content = candidate.Content.Parts[0].Text
 | 
			
		||||
				var builder strings.Builder
 | 
			
		||||
				for _, part := range candidate.Content.Parts {
 | 
			
		||||
					if i > 0 {
 | 
			
		||||
						builder.WriteString("\n")
 | 
			
		||||
					}
 | 
			
		||||
					builder.WriteString(part.Text)
 | 
			
		||||
				}
 | 
			
		||||
				choice.Message.Content = builder.String()
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			choice.Message.Content = ""
 | 
			
		||||
 
 | 
			
		||||
@@ -65,10 +65,12 @@ type ChatTools struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatGenerationConfig struct {
 | 
			
		||||
	Temperature     float64  `json:"temperature,omitempty"`
 | 
			
		||||
	TopP            float64  `json:"topP,omitempty"`
 | 
			
		||||
	TopK            float64  `json:"topK,omitempty"`
 | 
			
		||||
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
 | 
			
		||||
	CandidateCount  int      `json:"candidateCount,omitempty"`
 | 
			
		||||
	StopSequences   []string `json:"stopSequences,omitempty"`
 | 
			
		||||
	ResponseMimeType string   `json:"responseMimeType,omitempty"`
 | 
			
		||||
	ResponseSchema   any      `json:"responseSchema,omitempty"`
 | 
			
		||||
	Temperature      *float64 `json:"temperature,omitempty"`
 | 
			
		||||
	TopP             *float64 `json:"topP,omitempty"`
 | 
			
		||||
	TopK             float64  `json:"topK,omitempty"`
 | 
			
		||||
	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"`
 | 
			
		||||
	CandidateCount   int      `json:"candidateCount,omitempty"`
 | 
			
		||||
	StopSequences    []string `json:"stopSequences,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,9 +4,24 @@ package groq
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"gemma-7b-it",
 | 
			
		||||
	"llama2-7b-2048",
 | 
			
		||||
	"llama2-70b-4096",
 | 
			
		||||
	"mixtral-8x7b-32768",
 | 
			
		||||
	"llama3-8b-8192",
 | 
			
		||||
	"gemma2-9b-it",
 | 
			
		||||
	"llama-3.1-70b-versatile",
 | 
			
		||||
	"llama-3.1-8b-instant",
 | 
			
		||||
	"llama-3.2-11b-text-preview",
 | 
			
		||||
	"llama-3.2-11b-vision-preview",
 | 
			
		||||
	"llama-3.2-1b-preview",
 | 
			
		||||
	"llama-3.2-3b-preview",
 | 
			
		||||
	"llama-3.2-11b-vision-preview",
 | 
			
		||||
	"llama-3.2-90b-text-preview",
 | 
			
		||||
	"llama-3.2-90b-vision-preview",
 | 
			
		||||
	"llama-guard-3-8b",
 | 
			
		||||
	"llama3-70b-8192",
 | 
			
		||||
	"llama3-8b-8192",
 | 
			
		||||
	"llama3-groq-70b-8192-tool-use-preview",
 | 
			
		||||
	"llama3-groq-8b-8192-tool-use-preview",
 | 
			
		||||
	"llava-v1.5-7b-4096-preview",
 | 
			
		||||
	"mixtral-8x7b-32768",
 | 
			
		||||
	"distil-whisper-large-v3-en",
 | 
			
		||||
	"whisper-large-v3",
 | 
			
		||||
	"whisper-large-v3-turbo",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	// https://github.com/ollama/ollama/blob/main/docs/api.md
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
 | 
			
		||||
	if meta.Mode == relaymode.Embeddings {
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL)
 | 
			
		||||
	}
 | 
			
		||||
	return fullRequestURL, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
			TopP:             request.TopP,
 | 
			
		||||
			FrequencyPenalty: request.FrequencyPenalty,
 | 
			
		||||
			PresencePenalty:  request.PresencePenalty,
 | 
			
		||||
			NumPredict:       request.MaxTokens,
 | 
			
		||||
			NumCtx:           request.NumCtx,
 | 
			
		||||
		},
 | 
			
		||||
		Stream: request.Stream,
 | 
			
		||||
	}
 | 
			
		||||
@@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := strings.TrimPrefix(scanner.Text(), "}")
 | 
			
		||||
		data = data + "}"
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if strings.HasPrefix(data, "}") {
 | 
			
		||||
			data = strings.TrimPrefix(data, "}") + "}"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var ollamaResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &ollamaResponse)
 | 
			
		||||
@@ -157,8 +161,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
 | 
			
		||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 | 
			
		||||
	return &EmbeddingRequest{
 | 
			
		||||
		Model:  request.Model,
 | 
			
		||||
		Prompt: strings.Join(request.ParseInput(), " "),
 | 
			
		||||
		Model: request.Model,
 | 
			
		||||
		Input: request.ParseInput(),
 | 
			
		||||
		Options: &Options{
 | 
			
		||||
			Seed:             int(request.Seed),
 | 
			
		||||
			Temperature:      request.Temperature,
 | 
			
		||||
			TopP:             request.TopP,
 | 
			
		||||
			FrequencyPenalty: request.FrequencyPenalty,
 | 
			
		||||
			PresencePenalty:  request.PresencePenalty,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi
 | 
			
		||||
	openAIEmbeddingResponse := openai.EmbeddingResponse{
 | 
			
		||||
		Object: "list",
 | 
			
		||||
		Data:   make([]openai.EmbeddingResponseItem, 0, 1),
 | 
			
		||||
		Model:  "text-embedding-v1",
 | 
			
		||||
		Model:  response.Model,
 | 
			
		||||
		Usage:  model.Usage{TotalTokens: 0},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
 | 
			
		||||
		Object:    `embedding`,
 | 
			
		||||
		Index:     0,
 | 
			
		||||
		Embedding: response.Embedding,
 | 
			
		||||
	})
 | 
			
		||||
	for i, embedding := range response.Embeddings {
 | 
			
		||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
 | 
			
		||||
			Object:    `embedding`,
 | 
			
		||||
			Index:     i,
 | 
			
		||||
			Embedding: embedding,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &openAIEmbeddingResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,12 +1,14 @@
 | 
			
		||||
package ollama
 | 
			
		||||
 | 
			
		||||
type Options struct {
 | 
			
		||||
	Seed             int     `json:"seed,omitempty"`
 | 
			
		||||
	Temperature      float64 `json:"temperature,omitempty"`
 | 
			
		||||
	TopK             int     `json:"top_k,omitempty"`
 | 
			
		||||
	TopP             float64 `json:"top_p,omitempty"`
 | 
			
		||||
	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
 | 
			
		||||
	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
 | 
			
		||||
	Seed             int      `json:"seed,omitempty"`
 | 
			
		||||
	Temperature      *float64 `json:"temperature,omitempty"`
 | 
			
		||||
	TopK             int      `json:"top_k,omitempty"`
 | 
			
		||||
	TopP             *float64 `json:"top_p,omitempty"`
 | 
			
		||||
	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
 | 
			
		||||
	PresencePenalty  *float64 `json:"presence_penalty,omitempty"`
 | 
			
		||||
	NumPredict       int      `json:"num_predict,omitempty"`
 | 
			
		||||
	NumCtx           int      `json:"num_ctx,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
@@ -37,11 +39,15 @@ type ChatResponse struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmbeddingRequest struct {
 | 
			
		||||
	Model  string `json:"model"`
 | 
			
		||||
	Prompt string `json:"prompt"`
 | 
			
		||||
	Model string   `json:"model"`
 | 
			
		||||
	Input []string `json:"input"`
 | 
			
		||||
	// Truncate  bool     `json:"truncate,omitempty"`
 | 
			
		||||
	Options *Options `json:"options,omitempty"`
 | 
			
		||||
	// KeepAlive string   `json:"keep_alive,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmbeddingResponse struct {
 | 
			
		||||
	Error     string    `json:"error,omitempty"`
 | 
			
		||||
	Embedding []float64 `json:"embedding,omitempty"`
 | 
			
		||||
	Error      string      `json:"error,omitempty"`
 | 
			
		||||
	Model      string      `json:"model"`
 | 
			
		||||
	Embeddings [][]float64 `json:"embeddings"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
	if request.Stream {
 | 
			
		||||
		// always return usage in stream mode
 | 
			
		||||
		if request.StreamOptions == nil {
 | 
			
		||||
			request.StreamOptions = &model.StreamOptions{}
 | 
			
		||||
		}
 | 
			
		||||
		request.StreamOptions.IncludeUsage = true
 | 
			
		||||
	}
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -11,8 +11,10 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/mistral"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/novita"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/xai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -30,6 +32,8 @@ var CompatibleChannels = []int{
 | 
			
		||||
	channeltype.DeepSeek,
 | 
			
		||||
	channeltype.TogetherAI,
 | 
			
		||||
	channeltype.Novita,
 | 
			
		||||
	channeltype.SiliconFlow,
 | 
			
		||||
	channeltype.XAI,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetCompatibleChannelMeta(channelType int) (string, []string) {
 | 
			
		||||
@@ -60,6 +64,10 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
 | 
			
		||||
		return "doubao", doubao.ModelList
 | 
			
		||||
	case channeltype.Novita:
 | 
			
		||||
		return "novita", novita.ModelList
 | 
			
		||||
	case channeltype.SiliconFlow:
 | 
			
		||||
		return "siliconflow", siliconflow.ModelList
 | 
			
		||||
	case channeltype.XAI:
 | 
			
		||||
		return "xai", xai.ModelList
 | 
			
		||||
	default:
 | 
			
		||||
		return "openai", ModelList
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,10 @@ var ModelList = []string{
 | 
			
		||||
	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
 | 
			
		||||
	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
 | 
			
		||||
	"gpt-4o", "gpt-4o-2024-05-13",
 | 
			
		||||
	"gpt-4o-2024-08-06",
 | 
			
		||||
	"gpt-4o-2024-11-20",
 | 
			
		||||
	"chatgpt-4o-latest",
 | 
			
		||||
	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
 | 
			
		||||
	"gpt-4-vision-preview",
 | 
			
		||||
	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 | 
			
		||||
	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
 | 
			
		||||
@@ -17,4 +21,7 @@ var ModelList = []string{
 | 
			
		||||
	"dall-e-2", "dall-e-3",
 | 
			
		||||
	"whisper-1",
 | 
			
		||||
	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
 | 
			
		||||
	"o1", "o1-2024-12-17",
 | 
			
		||||
	"o1-preview", "o1-preview-2024-09-12",
 | 
			
		||||
	"o1-mini", "o1-mini-2024-09-12",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,15 +2,16 @@ package openai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
 | 
			
		||||
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
 | 
			
		||||
	usage := &model.Usage{}
 | 
			
		||||
	usage.PromptTokens = promptTokens
 | 
			
		||||
	usage.CompletionTokens = CountTokenText(responseText, modeName)
 | 
			
		||||
	usage.CompletionTokens = CountTokenText(responseText, modelName)
 | 
			
		||||
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 | 
			
		||||
	return usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 | 
			
		||||
				render.StringData(c, data) // if error happened, pass the data to client
 | 
			
		||||
				continue                   // just ignore the error
 | 
			
		||||
			}
 | 
			
		||||
			if len(streamResponse.Choices) == 0 {
 | 
			
		||||
				// but for empty choice, we should not pass it to client, this is for azure
 | 
			
		||||
			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
 | 
			
		||||
				// but for empty choice and no usage, we should not pass it to client, this is for azure
 | 
			
		||||
				continue // just ignore empty choice
 | 
			
		||||
			}
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
 
 | 
			
		||||
@@ -97,7 +97,11 @@ func CountTokenMessages(messages []model.Message, model string) int {
 | 
			
		||||
				m := it.(map[string]any)
 | 
			
		||||
				switch m["type"] {
 | 
			
		||||
				case "text":
 | 
			
		||||
					tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
 | 
			
		||||
					if textValue, ok := m["text"]; ok {
 | 
			
		||||
						if textString, ok := textValue.(string); ok {
 | 
			
		||||
							tokenNum += getTokenNum(tokenEncoder, textString)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				case "image_url":
 | 
			
		||||
					imageUrl, ok := m["image_url"].(map[string]any)
 | 
			
		||||
					if ok {
 | 
			
		||||
@@ -106,7 +110,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
 | 
			
		||||
						if imageUrl["detail"] != nil {
 | 
			
		||||
							detail = imageUrl["detail"].(string)
 | 
			
		||||
						}
 | 
			
		||||
						imageTokens, err := countImageTokens(url, detail)
 | 
			
		||||
						imageTokens, err := countImageTokens(url, detail, model)
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							logger.SysError("error counting image tokens: " + err.Error())
 | 
			
		||||
						} else {
 | 
			
		||||
@@ -130,11 +134,15 @@ const (
 | 
			
		||||
	lowDetailCost         = 85
 | 
			
		||||
	highDetailCostPerTile = 170
 | 
			
		||||
	additionalCost        = 85
 | 
			
		||||
	// gpt-4o-mini cost higher than other model
 | 
			
		||||
	gpt4oMiniLowDetailCost  = 2833
 | 
			
		||||
	gpt4oMiniHighDetailCost = 5667
 | 
			
		||||
	gpt4oMiniAdditionalCost = 2833
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/guides/vision/calculating-costs
 | 
			
		||||
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
 | 
			
		||||
func countImageTokens(url string, detail string) (_ int, err error) {
 | 
			
		||||
func countImageTokens(url string, detail string, model string) (_ int, err error) {
 | 
			
		||||
	var fetchSize = true
 | 
			
		||||
	var width, height int
 | 
			
		||||
	// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
 | 
			
		||||
@@ -168,6 +176,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
 | 
			
		||||
	}
 | 
			
		||||
	switch detail {
 | 
			
		||||
	case "low":
 | 
			
		||||
		if strings.HasPrefix(model, "gpt-4o-mini") {
 | 
			
		||||
			return gpt4oMiniLowDetailCost, nil
 | 
			
		||||
		}
 | 
			
		||||
		return lowDetailCost, nil
 | 
			
		||||
	case "high":
 | 
			
		||||
		if fetchSize {
 | 
			
		||||
@@ -187,6 +198,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
 | 
			
		||||
			height = int(float64(height) * ratio)
 | 
			
		||||
		}
 | 
			
		||||
		numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
 | 
			
		||||
		if strings.HasPrefix(model, "gpt-4o-mini") {
 | 
			
		||||
			return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil
 | 
			
		||||
		}
 | 
			
		||||
		result := numSquares*highDetailCostPerTile + additionalCost
 | 
			
		||||
		return result, nil
 | 
			
		||||
	default:
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,16 @@
 | 
			
		||||
package openai
 | 
			
		||||
 | 
			
		||||
import "github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
 | 
			
		||||
	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
 | 
			
		||||
 | 
			
		||||
	Error := model.Error{
 | 
			
		||||
		Message: err.Error(),
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
 
 | 
			
		||||
@@ -19,11 +19,11 @@ type Prompt struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
	Prompt         Prompt  `json:"prompt"`
 | 
			
		||||
	Temperature    float64 `json:"temperature,omitempty"`
 | 
			
		||||
	CandidateCount int     `json:"candidateCount,omitempty"`
 | 
			
		||||
	TopP           float64 `json:"topP,omitempty"`
 | 
			
		||||
	TopK           int     `json:"topK,omitempty"`
 | 
			
		||||
	Prompt         Prompt   `json:"prompt"`
 | 
			
		||||
	Temperature    *float64 `json:"temperature,omitempty"`
 | 
			
		||||
	CandidateCount int      `json:"candidateCount,omitempty"`
 | 
			
		||||
	TopP           *float64 `json:"topP,omitempty"`
 | 
			
		||||
	TopK           int      `json:"topK,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Error struct {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										89
									
								
								relay/adaptor/proxy/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								relay/adaptor/proxy/adaptor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,89 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	relaymodel "github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ adaptor.Adaptor = new(Adaptor)
 | 
			
		||||
 | 
			
		||||
const channelName = "proxy"
 | 
			
		||||
 | 
			
		||||
type Adaptor struct{}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("notimplement")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		for _, vv := range v {
 | 
			
		||||
			c.Writer.Header().Set(k, vv)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil {
 | 
			
		||||
		return nil, &relaymodel.ErrorWithStatusCode{
 | 
			
		||||
			StatusCode: http.StatusInternalServerError,
 | 
			
		||||
			Error: relaymodel.Error{
 | 
			
		||||
				Message: gerr.Error(),
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetModelList() (models []string) {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetChannelName() string {
 | 
			
		||||
	return channelName
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRequestURL remove static prefix, and return the real request url to the upstream service
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId)
 | 
			
		||||
	return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
 | 
			
		||||
	for k, v := range c.Request.Header {
 | 
			
		||||
		req.Header.Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove unnecessary headers
 | 
			
		||||
	req.Header.Del("Host")
 | 
			
		||||
	req.Header.Del("Content-Length")
 | 
			
		||||
	req.Header.Del("Accept-Encoding")
 | 
			
		||||
	req.Header.Del("Connection")
 | 
			
		||||
 | 
			
		||||
	// set authorization header
 | 
			
		||||
	req.Header.Set("Authorization", meta.APIKey)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return nil, errors.Errorf("not implement")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channelhelper.DoRequestHelper(a, c, meta, requestBody)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,136 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
	meta *meta.Meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertImageRequest implements adaptor.Adaptor.
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return DrawImageRequest{
 | 
			
		||||
		Input: ImageInput{
 | 
			
		||||
			Steps:           25,
 | 
			
		||||
			Prompt:          request.Prompt,
 | 
			
		||||
			Guidance:        3,
 | 
			
		||||
			Seed:            int(time.Now().UnixNano()),
 | 
			
		||||
			SafetyTolerance: 5,
 | 
			
		||||
			NImages:         1, // replicate will always return 1 image
 | 
			
		||||
			Width:           1440,
 | 
			
		||||
			Height:          1440,
 | 
			
		||||
			AspectRatio:     "1:1",
 | 
			
		||||
		},
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	if !request.Stream {
 | 
			
		||||
		// TODO: support non-stream mode
 | 
			
		||||
		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Build the prompt from OpenAI messages
 | 
			
		||||
	var promptBuilder strings.Builder
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		switch msgCnt := message.Content.(type) {
 | 
			
		||||
		case string:
 | 
			
		||||
			promptBuilder.WriteString(message.Role)
 | 
			
		||||
			promptBuilder.WriteString(": ")
 | 
			
		||||
			promptBuilder.WriteString(msgCnt)
 | 
			
		||||
			promptBuilder.WriteString("\n")
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	replicateRequest := ReplicateChatRequest{
 | 
			
		||||
		Input: ChatInput{
 | 
			
		||||
			Prompt:           promptBuilder.String(),
 | 
			
		||||
			MaxTokens:        request.MaxTokens,
 | 
			
		||||
			Temperature:      1.0,
 | 
			
		||||
			TopP:             1.0,
 | 
			
		||||
			PresencePenalty:  0.0,
 | 
			
		||||
			FrequencyPenalty: 0.0,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Map optional fields
 | 
			
		||||
	if request.Temperature != nil {
 | 
			
		||||
		replicateRequest.Input.Temperature = *request.Temperature
 | 
			
		||||
	}
 | 
			
		||||
	if request.TopP != nil {
 | 
			
		||||
		replicateRequest.Input.TopP = *request.TopP
 | 
			
		||||
	}
 | 
			
		||||
	if request.PresencePenalty != nil {
 | 
			
		||||
		replicateRequest.Input.PresencePenalty = *request.PresencePenalty
 | 
			
		||||
	}
 | 
			
		||||
	if request.FrequencyPenalty != nil {
 | 
			
		||||
		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
 | 
			
		||||
	}
 | 
			
		||||
	if request.MaxTokens > 0 {
 | 
			
		||||
		replicateRequest.Input.MaxTokens = request.MaxTokens
 | 
			
		||||
	} else if request.MaxTokens == 0 {
 | 
			
		||||
		replicateRequest.Input.MaxTokens = 500
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return replicateRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
	a.meta = meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	if !slices.Contains(ModelList, meta.OriginModelName) {
 | 
			
		||||
		return "", errors.Errorf("model %s not supported", meta.OriginModelName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
 | 
			
		||||
	adaptor.SetupCommonRequestHeader(c, req, meta)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DoRequest logs the outgoing call and delegates the actual HTTP round
// trip to the shared adaptor helper.
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	logger.Info(c, "send request to replicate")
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ImagesGenerations:
 | 
			
		||||
		err, usage = ImageHandler(c, resp)
 | 
			
		||||
	case relaymode.ChatCompletions:
 | 
			
		||||
		err, usage = ChatHandler(c, resp)
 | 
			
		||||
	default:
 | 
			
		||||
		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetModelList returns the models supported by this adaptor.
func (a *Adaptor) GetModelList() []string {
	return ModelList
}
 | 
			
		||||
 | 
			
		||||
// GetChannelName returns the identifier used for this channel type.
func (a *Adaptor) GetChannelName() string {
	return "replicate"
}
 | 
			
		||||
							
								
								
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								relay/adaptor/replicate/chat.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,191 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ChatHandler consumes a freshly created Replicate prediction (the
// upstream must answer 201 Created), polls the prediction's "get" URL
// until the task reaches a terminal state, then streams the task's
// output to the client and converts the streamed text into usage
// accounting.
//
// srvErr is non-nil on failure; usage is non-nil only on success.
func ChatHandler(c *gin.Context, resp *http.Response) (
	srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
	// Replicate acknowledges a new prediction with 201 Created.
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ChatResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	// Poll until the task finishes. The closure returns errNextLoop to
	// signal "not done yet, poll again"; any other error aborts.
	for {
		err = func() error {
			// Fetch the current task state from the prediction's "get" URL.
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ChatResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
				// Terminal success: fall through to stream the output.
			case "failed", "canceled":
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
			default:
				// Still pending/starting/processing: wait, then poll again.
				// NOTE(review): fixed 3s interval with no upper bound on
				// retries — relies on the request context for cancellation.
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			if taskData.URLs.Stream == "" {
				return errors.New("stream url is empty")
			}

			// Relay the SSE stream to the client and collect the full text.
			responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
			if err != nil {
				return errors.Wrap(err, "chat stream handler")
			}

			// Assigns the named return `usage` captured by this closure.
			ctxMeta := meta.GetByContext(c)
			usage = openai.ResponseText2Usage(responseText,
				ctxMeta.ActualModelName, ctxMeta.PromptTokens)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, usage
}
 | 
			
		||||
 | 
			
		||||
// SSE protocol markers used when parsing Replicate's event stream.
const (
	eventPrefix = "event: "
	dataPrefix  = "data: "
	done        = "[DONE]"
)
 | 
			
		||||
 | 
			
		||||
// chatStreamHandler connects to the prediction's SSE stream URL,
// forwards each "output" event to the client as server-sent events,
// and returns the concatenated output text once a "done" event arrives
// (or the stream ends).
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
	// Build the GET request against the stream endpoint.
	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
	if err != nil {
		return "", errors.Wrap(err, "new request to stream")
	}

	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
	streamReq.Header.Set("Accept", "text/event-stream")
	streamReq.Header.Set("Cache-Control", "no-store")

	resp, err := http.DefaultClient.Do(streamReq)
	if err != nil {
		return "", errors.Wrap(err, "do request to stream")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		payload, _ := io.ReadAll(resp.Body)
		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	doneRendered := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// SSE comments start with ':' and carry no payload.
		if strings.HasPrefix(line, ":") {
			continue
		}

		// An "event:" line starts a record; the data (and optional id)
		// follow on subsequent lines until a blank separator.
		if strings.HasPrefix(line, eventPrefix) {
			event := strings.TrimSpace(line[len(eventPrefix):])
			var data string
			// Consume the rest of this SSE record.
			// NOTE(review): only the last "data:" line is kept, so
			// multi-line data payloads would be truncated — confirm
			// Replicate emits single-line data fields.
			for scanner.Scan() {
				nextLine := scanner.Text()
				if nextLine == "" {
					break
				}
				if strings.HasPrefix(nextLine, dataPrefix) {
					data = nextLine[len(dataPrefix):]
				} else if strings.HasPrefix(nextLine, "id:") {
					// id field is currently unused.
					// id = strings.TrimSpace(nextLine[len("id:"):])
				}
			}

			if event == "output" {
				// Forward the chunk to the client and accumulate it.
				render.StringData(c, data)
				responseText += data
			} else if event == "done" {
				render.Done(c)
				doneRendered = true
				break
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", errors.Wrap(err, "scan stream")
	}

	// Always terminate the client stream, even if upstream never sent "done".
	if !doneRendered {
		render.Done(c)
	}

	return responseText, nil
}
 | 
			
		||||
							
								
								
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,58 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
// ModelList is a list of models that can be used with Replicate.
// The order is preserved as returned by GetModelList.
//
// https://replicate.com/pricing
var ModelList = []string{
	// -------------------------------------
	// image model
	// -------------------------------------
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",
	// -------------------------------------
	// language model
	// -------------------------------------
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",
	// -------------------------------------
	// video model
	// -------------------------------------
	// "minimax/video-01",  // TODO: implement the adaptor
}
 | 
			
		||||
							
								
								
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,222 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	_ "image/jpeg" // register the JPEG decoder so image.Decode can handle JPEG payloads
	"image/png"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"
)
 | 
			
		||||
 | 
			
		||||
// ImagesEditsHandler just copy response body to client
//
// https://replicate.com/black-forest-labs/flux-fill-pro
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// 	c.Writer.WriteHeader(resp.StatusCode)
// 	for k, v := range resp.Header {
// 		c.Writer.Header().Set(k, v[0])
// 	}

// 	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
// 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	defer resp.Body.Close()

// 	return nil, nil
// }

// errNextLoop is a sentinel used by the polling loops in ImageHandler
// and ChatHandler to mean "task not finished yet, poll again".
var errNextLoop = errors.New("next_loop")
 | 
			
		||||
 | 
			
		||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	if resp.StatusCode != http.StatusCreated {
 | 
			
		||||
		payload, _ := io.ReadAll(resp.Body)
 | 
			
		||||
		return openai.ErrorWrapper(
 | 
			
		||||
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
 | 
			
		||||
				"bad_status_code", http.StatusInternalServerError),
 | 
			
		||||
			nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respData := new(ImageResponse)
 | 
			
		||||
	if err = json.Unmarshal(respBody, respData); err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		err = func() error {
 | 
			
		||||
			// get task
 | 
			
		||||
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
 | 
			
		||||
				http.MethodGet, respData.URLs.Get, nil)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "new request")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
 | 
			
		||||
			taskResp, err := http.DefaultClient.Do(taskReq)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "get task")
 | 
			
		||||
			}
 | 
			
		||||
			defer taskResp.Body.Close()
 | 
			
		||||
 | 
			
		||||
			if taskResp.StatusCode != http.StatusOK {
 | 
			
		||||
				payload, _ := io.ReadAll(taskResp.Body)
 | 
			
		||||
				return errors.Errorf("bad status code [%d]%s",
 | 
			
		||||
					taskResp.StatusCode, string(payload))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			taskBody, err := io.ReadAll(taskResp.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "read task response")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			taskData := new(ImageResponse)
 | 
			
		||||
			if err = json.Unmarshal(taskBody, taskData); err != nil {
 | 
			
		||||
				return errors.Wrap(err, "decode task response")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			switch taskData.Status {
 | 
			
		||||
			case "succeeded":
 | 
			
		||||
			case "failed", "canceled":
 | 
			
		||||
				return errors.Errorf("task failed: %s", taskData.Status)
 | 
			
		||||
			default:
 | 
			
		||||
				time.Sleep(time.Second * 3)
 | 
			
		||||
				return errNextLoop
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			output, err := taskData.GetOutput()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "get output")
 | 
			
		||||
			}
 | 
			
		||||
			if len(output) == 0 {
 | 
			
		||||
				return errors.New("response output is empty")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var mu sync.Mutex
 | 
			
		||||
			var pool errgroup.Group
 | 
			
		||||
			respBody := &openai.ImageResponse{
 | 
			
		||||
				Created: taskData.CompletedAt.Unix(),
 | 
			
		||||
				Data:    []openai.ImageData{},
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, imgOut := range output {
 | 
			
		||||
				imgOut := imgOut
 | 
			
		||||
				pool.Go(func() error {
 | 
			
		||||
					// download image
 | 
			
		||||
					downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
 | 
			
		||||
						http.MethodGet, imgOut, nil)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "new request")
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					imgResp, err := http.DefaultClient.Do(downloadReq)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "download image")
 | 
			
		||||
					}
 | 
			
		||||
					defer imgResp.Body.Close()
 | 
			
		||||
 | 
			
		||||
					if imgResp.StatusCode != http.StatusOK {
 | 
			
		||||
						payload, _ := io.ReadAll(imgResp.Body)
 | 
			
		||||
						return errors.Errorf("bad status code [%d]%s",
 | 
			
		||||
							imgResp.StatusCode, string(payload))
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					imgData, err := io.ReadAll(imgResp.Body)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "read image")
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					imgData, err = ConvertImageToPNG(imgData)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "convert image")
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					mu.Lock()
 | 
			
		||||
					respBody.Data = append(respBody.Data, openai.ImageData{
 | 
			
		||||
						B64Json: fmt.Sprintf("data:image/png;base64,%s",
 | 
			
		||||
							base64.StdEncoding.EncodeToString(imgData)),
 | 
			
		||||
					})
 | 
			
		||||
					mu.Unlock()
 | 
			
		||||
 | 
			
		||||
					return nil
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err := pool.Wait(); err != nil {
 | 
			
		||||
				if len(respBody.Data) == 0 {
 | 
			
		||||
					return errors.WithStack(err)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			c.JSON(http.StatusOK, respBody)
 | 
			
		||||
			return nil
 | 
			
		||||
		}()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if errors.Is(err, errNextLoop) {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertImageToPNG normalizes a downloaded image payload to PNG.
// PNG input is returned unchanged; JPEG and WebP inputs are decoded
// and re-encoded as PNG. The format is sniffed from magic bytes.
//
// NOTE(review): the JPEG branch relies on image.Decode, which only
// works when a JPEG decoder is registered — verify this file
// blank-imports "image/jpeg".
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
	// bypass if it's already a PNG image (magic header "\x89PNG")
	if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
		return webpData, nil
	}

	// JPEG magic header "\xff\xd8\xff": decode and re-encode as PNG.
	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
		img, _, err := image.Decode(bytes.NewReader(webpData))
		if err != nil {
			return nil, errors.Wrap(err, "decode jpeg")
		}

		var pngBuffer bytes.Buffer
		if err := png.Encode(&pngBuffer, img); err != nil {
			return nil, errors.Wrap(err, "encode png")
		}

		return pngBuffer.Bytes(), nil
	}

	// Anything else is assumed to be WebP (Replicate's default output).
	img, err := webp.Decode(bytes.NewReader(webpData))
	if err != nil {
		return nil, errors.Wrap(err, "decode webp")
	}

	// Encode the decoded image as PNG.
	var pngBuffer bytes.Buffer
	if err := png.Encode(&pngBuffer, img); err != nil {
		return nil, errors.Wrap(err, "encode png")
	}

	return pngBuffer.Bytes(), nil
}
 | 
			
		||||
							
								
								
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,159 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DrawImageRequest is the Replicate prediction payload for image
// generation (e.g. the flux-pro family).
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
	// Input carries the model-specific generation parameters.
	Input ImageInput `json:"input"`
}

// ImageInput is the input section of DrawImageRequest.
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
	Steps           int    `json:"steps" binding:"required,min=1"`
	Prompt          string `json:"prompt" binding:"required,min=5"`
	ImagePrompt     string `json:"image_prompt"`
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	Seed            int    `json:"seed"`
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
}
 | 
			
		||||
 | 
			
		||||
// InpaintingImageByFlusReplicateRequest is the request payload for
// inpainting with flux-fill-pro. (The "Flus" in the name is a historic
// typo kept for compatibility with existing callers.)
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
	Input FluxInpaintingInput `json:"input"`
}

// FluxInpaintingInput is the input section of
// InpaintingImageByFlusReplicateRequest.
//
// NOTE(review): the upstream schema names the last field
// "prompt_upsampling"; verify the "prompt_unsampling" JSON tag below
// against the current Replicate API before relying on it.
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
	Mask             string `json:"mask" binding:"required"`
	Image            string `json:"image" binding:"required"`
	Seed             int    `json:"seed"`
	Steps            int    `json:"steps" binding:"required,min=1"`
	Prompt           string `json:"prompt" binding:"required,min=5"`
	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
	OutputFormat     string `json:"output_format"`
	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	PromptUnsampling bool   `json:"prompt_unsampling"`
}
 | 
			
		||||
 | 
			
		||||
// ImageResponse is the Replicate prediction object returned for image
// generation requests, both on creation and when polling the task.
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
	CompletedAt time.Time        `json:"completed_at"`
	CreatedAt   time.Time        `json:"created_at"`
	DataRemoved bool             `json:"data_removed"`
	Error       string           `json:"error"`
	ID          string           `json:"id"`
	Input       DrawImageRequest `json:"input"`
	Logs        string           `json:"logs"`
	Metrics     FluxMetrics      `json:"metrics"`
	// Output could be `string` or `[]string`; use GetOutput to
	// normalize it into a []string.
	Output    any       `json:"output"`
	StartedAt time.Time `json:"started_at"`
	Status    string    `json:"status"`
	URLs      FluxURLs  `json:"urls"`
	Version   string    `json:"version"`
}
 | 
			
		||||
 | 
			
		||||
func (r *ImageResponse) GetOutput() ([]string, error) {
 | 
			
		||||
	switch v := r.Output.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return []string{v}, nil
 | 
			
		||||
	case []string:
 | 
			
		||||
		return v, nil
 | 
			
		||||
	case nil:
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	case []interface{}:
 | 
			
		||||
		// convert []interface{} to []string
 | 
			
		||||
		ret := make([]string, len(v))
 | 
			
		||||
		for idx, vv := range v {
 | 
			
		||||
			if vvv, ok := vv.(string); ok {
 | 
			
		||||
				ret[idx] = vvv
 | 
			
		||||
			} else {
 | 
			
		||||
				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return ret, nil
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FluxMetrics holds the timing/count metrics Replicate reports on a
// finished prediction.
type FluxMetrics struct {
	ImageCount  int     `json:"image_count"`
	PredictTime float64 `json:"predict_time"`
	TotalTime   float64 `json:"total_time"`
}

// FluxURLs holds the task URLs of an ImageResponse ("get" to poll,
// "cancel" to abort).
type FluxURLs struct {
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
 | 
			
		||||
 | 
			
		||||
// ReplicateChatRequest is the prediction payload sent to Replicate for
// chat-completion (language) models.
type ReplicateChatRequest struct {
	Input ChatInput `json:"input" form:"input" binding:"required"`
}

// ChatInput is the input section of ReplicateChatRequest.
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	MinTokens        int     `json:"min_tokens"`
	Temperature      float64 `json:"temperature"`
	SystemPrompt     string  `json:"system_prompt"`
	StopSequences    string  `json:"stop_sequences"`
	PromptTemplate   string  `json:"prompt_template"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}
 | 
			
		||||
 | 
			
		||||
// ChatResponse is the Replicate prediction object returned for chat
// requests, both on creation and when polling the task.
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
	CompletedAt time.Time   `json:"completed_at"`
	CreatedAt   time.Time   `json:"created_at"`
	DataRemoved bool        `json:"data_removed"`
	Error       string      `json:"error"`
	ID          string      `json:"id"`
	Input       ChatInput   `json:"input"`
	Logs        string      `json:"logs"`
	Metrics     FluxMetrics `json:"metrics"`
	// Output is declared []string here (unlike ImageResponse.Output,
	// which is untyped).
	Output    []string        `json:"output"`
	StartedAt time.Time       `json:"started_at"`
	Status    string          `json:"status"`
	URLs      ChatResponseUrl `json:"urls"`
	Version   string          `json:"version"`
}

// ChatResponseUrl holds the task URLs of a ChatResponse: "stream" for
// the SSE output, "get" to poll, "cancel" to abort.
type ChatResponseUrl struct {
	Stream string `json:"stream"`
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
 | 
			
		||||
							
								
								
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								relay/adaptor/siliconflow/constants.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
package siliconflow
 | 
			
		||||
 | 
			
		||||
// ModelList enumerates the SiliconFlow models exposed by this adaptor.
//
// https://docs.siliconflow.cn/docs/getting-started
var ModelList = []string{
	"deepseek-ai/deepseek-llm-67b-chat",
	"Qwen/Qwen1.5-14B-Chat",
	"Qwen/Qwen1.5-7B-Chat",
	"Qwen/Qwen1.5-110B-Chat",
	"Qwen/Qwen1.5-32B-Chat",
	"01-ai/Yi-1.5-6B-Chat",
	"01-ai/Yi-1.5-9B-Chat-16K",
	"01-ai/Yi-1.5-34B-Chat-16K",
	"THUDM/chatglm3-6b",
	"deepseek-ai/DeepSeek-V2-Chat",
	"THUDM/glm-4-9b-chat",
	"Qwen/Qwen2-72B-Instruct",
	"Qwen/Qwen2-7B-Instruct",
	"Qwen/Qwen2-57B-A14B-Instruct",
	"deepseek-ai/DeepSeek-Coder-V2-Instruct",
	"Qwen/Qwen2-1.5B-Instruct",
	"internlm/internlm2_5-7b-chat",
	"BAAI/bge-large-en-v1.5",
	"BAAI/bge-large-zh-v1.5",
	"Pro/Qwen/Qwen2-7B-Instruct",
	"Pro/Qwen/Qwen2-1.5B-Instruct",
	"Pro/Qwen/Qwen1.5-7B-Chat",
	"Pro/THUDM/glm-4-9b-chat",
	"Pro/THUDM/chatglm3-6b",
	"Pro/01-ai/Yi-1.5-9B-Chat-16K",
	"Pro/01-ai/Yi-1.5-6B-Chat",
	"Pro/google/gemma-2-9b-it",
	"Pro/internlm/internlm2_5-7b-chat",
	"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
	"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}
 | 
			
		||||
@@ -1,7 +1,13 @@
 | 
			
		||||
package stepfun
 | 
			
		||||
 | 
			
		||||
// ModelList enumerates the StepFun "step" family models supported by this
// adaptor, covering text, flash, vision (1v) and image (1x) variants.
var ModelList = []string{
	"step-1-8k",
	"step-1-32k",
	"step-1-128k",
	"step-1-256k",
	"step-1-flash",
	"step-2-16k",
	"step-1v-8k",
	"step-1v-32k",
	"step-1-200k",
	"step-1x-medium",
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,16 +2,19 @@ package tencent
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://cloud.tencent.com/document/api/1729/101837
 | 
			
		||||
@@ -52,10 +55,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	tencentRequest := ConvertRequest(*request)
 | 
			
		||||
	var convertedRequest any
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case relaymode.Embeddings:
 | 
			
		||||
		a.Action = "GetEmbedding"
 | 
			
		||||
		convertedRequest = ConvertEmbeddingRequest(*request)
 | 
			
		||||
	default:
 | 
			
		||||
		a.Action = "ChatCompletions"
 | 
			
		||||
		convertedRequest = ConvertRequest(*request)
 | 
			
		||||
	}
 | 
			
		||||
	// we have to calculate the sign here
 | 
			
		||||
	a.Sign = GetSign(*tencentRequest, a, secretId, secretKey)
 | 
			
		||||
	return tencentRequest, nil
 | 
			
		||||
	a.Sign = GetSign(convertedRequest, a, secretId, secretKey)
 | 
			
		||||
	return convertedRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
@@ -75,7 +86,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
 | 
			
		||||
		err, responseText = StreamHandler(c, resp)
 | 
			
		||||
		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = Handler(c, resp)
 | 
			
		||||
		switch meta.Mode {
 | 
			
		||||
		case relaymode.Embeddings:
 | 
			
		||||
			err, usage = EmbeddingHandler(c, resp)
 | 
			
		||||
		default:
 | 
			
		||||
			err, usage = Handler(c, resp)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,4 +5,6 @@ var ModelList = []string{
 | 
			
		||||
	"hunyuan-standard",
 | 
			
		||||
	"hunyuan-standard-256K",
 | 
			
		||||
	"hunyuan-pro",
 | 
			
		||||
	"hunyuan-vision",
 | 
			
		||||
	"hunyuan-embedding",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,6 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -16,11 +15,14 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/conv"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
@@ -39,13 +41,73 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
		Model:       &request.Model,
 | 
			
		||||
		Stream:      &request.Stream,
 | 
			
		||||
		Messages:    messages,
 | 
			
		||||
		TopP:        &request.TopP,
 | 
			
		||||
		Temperature: &request.Temperature,
 | 
			
		||||
		TopP:        request.TopP,
 | 
			
		||||
		Temperature: request.Temperature,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 | 
			
		||||
	return &EmbeddingRequest{
 | 
			
		||||
		InputList: request.ParseInput(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var tencentResponseP EmbeddingResponseP
 | 
			
		||||
	err := json.NewDecoder(resp.Body).Decode(&tencentResponseP)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tencentResponse := tencentResponseP.Response
 | 
			
		||||
	if tencentResponse.Error.Code != "" {
 | 
			
		||||
		return &model.ErrorWithStatusCode{
 | 
			
		||||
			Error: model.Error{
 | 
			
		||||
				Message: tencentResponse.Error.Message,
 | 
			
		||||
				Code:    tencentResponse.Error.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	requestModel := c.GetString(ctxkey.RequestModel)
 | 
			
		||||
	fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse)
 | 
			
		||||
	fullTextResponse.Model = requestModel
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
 | 
			
		||||
	openAIEmbeddingResponse := openai.EmbeddingResponse{
 | 
			
		||||
		Object: "list",
 | 
			
		||||
		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
 | 
			
		||||
		Model:  "hunyuan-embedding",
 | 
			
		||||
		Usage:  model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, item := range response.Data {
 | 
			
		||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
 | 
			
		||||
			Object:    item.Object,
 | 
			
		||||
			Index:     item.Index,
 | 
			
		||||
			Embedding: item.Embedding,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &openAIEmbeddingResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
			
		||||
	fullTextResponse := openai.TextResponse{
 | 
			
		||||
		Id:      response.ReqID,
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: helper.GetTimestamp(),
 | 
			
		||||
		Usage: model.Usage{
 | 
			
		||||
@@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	TencentResponse = responseP.Response
 | 
			
		||||
	if TencentResponse.Error.Code != 0 {
 | 
			
		||||
	if TencentResponse.Error.Code != "" {
 | 
			
		||||
		return &model.ErrorWithStatusCode{
 | 
			
		||||
			Error: model.Error{
 | 
			
		||||
				Message: TencentResponse.Error.Message,
 | 
			
		||||
@@ -195,7 +257,7 @@ func hmacSha256(s, key string) string {
 | 
			
		||||
	return string(hashed.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
 | 
			
		||||
func GetSign(req any, adaptor *Adaptor, secId, secKey string) string {
 | 
			
		||||
	// build canonical request string
 | 
			
		||||
	host := "hunyuan.tencentcloudapi.com"
 | 
			
		||||
	httpRequestMethod := "POST"
 | 
			
		||||
 
 | 
			
		||||
@@ -35,16 +35,16 @@ type ChatRequest struct {
 | 
			
		||||
	// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
 | 
			
		||||
	// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
 | 
			
		||||
	// 3. 非必要不建议使用,不合理的取值会影响效果。
 | 
			
		||||
	TopP *float64 `json:"TopP"`
 | 
			
		||||
	TopP *float64 `json:"TopP,omitempty"`
 | 
			
		||||
	// 说明:
 | 
			
		||||
	// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
 | 
			
		||||
	// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
 | 
			
		||||
	// 3. 非必要不建议使用,不合理的取值会影响效果。
 | 
			
		||||
	Temperature *float64 `json:"Temperature"`
 | 
			
		||||
	Temperature *float64 `json:"Temperature,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Error struct {
 | 
			
		||||
	Code    int    `json:"Code"`
 | 
			
		||||
	Code    string `json:"Code"`
 | 
			
		||||
	Message string `json:"Message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -61,15 +61,41 @@ type ResponseChoices struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatResponse struct {
 | 
			
		||||
	Choices []ResponseChoices `json:"Choices,omitempty"` // 结果
 | 
			
		||||
	Created int64             `json:"Created,omitempty"` // unix 时间戳的字符串
 | 
			
		||||
	Id      string            `json:"Id,omitempty"`      // 会话 id
 | 
			
		||||
	Usage   Usage             `json:"Usage,omitempty"`   // token 数量
 | 
			
		||||
	Error   Error             `json:"Error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值
 | 
			
		||||
	Note    string            `json:"Note,omitempty"`    // 注释
 | 
			
		||||
	ReqID   string            `json:"Req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
 | 
			
		||||
	Choices []ResponseChoices `json:"Choices,omitempty"`   // 结果
 | 
			
		||||
	Created int64             `json:"Created,omitempty"`   // unix 时间戳的字符串
 | 
			
		||||
	Id      string            `json:"Id,omitempty"`        // 会话 id
 | 
			
		||||
	Usage   Usage             `json:"Usage,omitempty"`     // token 数量
 | 
			
		||||
	Error   Error             `json:"Error,omitempty"`     // 错误信息 注意:此字段可能返回 null,表示取不到有效值
 | 
			
		||||
	Note    string            `json:"Note,omitempty"`      // 注释
 | 
			
		||||
	ReqID   string            `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatResponseP struct {
 | 
			
		||||
	Response ChatResponse `json:"Response,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// EmbeddingRequest is the Hunyuan GetEmbedding request body.
type EmbeddingRequest struct {
	InputList []string `json:"InputList"`
}

// EmbeddingData is a single embedding vector within a Hunyuan response.
type EmbeddingData struct {
	Embedding []float64 `json:"Embedding"`
	Index     int       `json:"Index"`
	Object    string    `json:"Object"`
}

// EmbeddingUsage is the token accounting attached to an embedding response.
type EmbeddingUsage struct {
	PromptTokens int `json:"PromptTokens"`
	TotalTokens  int `json:"TotalTokens"`
}
 | 
			
		||||
 | 
			
		||||
type EmbeddingResponse struct {
 | 
			
		||||
	Data           []EmbeddingData `json:"Data"`
 | 
			
		||||
	EmbeddingUsage EmbeddingUsage  `json:"Usage,omitempty"`
 | 
			
		||||
	RequestId      string          `json:"RequestId,omitempty"`
 | 
			
		||||
	Error          Error           `json:"Error,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmbeddingResponseP struct {
 | 
			
		||||
	Response EmbeddingResponse `json:"Response,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										117
									
								
								relay/adaptor/vertexai/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								relay/adaptor/vertexai/adaptor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,117 @@
 | 
			
		||||
package vertexai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	relaymodel "github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ adaptor.Adaptor = new(Adaptor)
 | 
			
		||||
 | 
			
		||||
const channelName = "vertexai"
 | 
			
		||||
 | 
			
		||||
type Adaptor struct{}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	adaptor := GetAdaptor(request.Model)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return nil, errors.New("adaptor not found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return adaptor.ConvertRequest(c, relayMode, request)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	adaptor := GetAdaptor(meta.ActualModelName)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return nil, &relaymodel.ErrorWithStatusCode{
 | 
			
		||||
			StatusCode: http.StatusInternalServerError,
 | 
			
		||||
			Error: relaymodel.Error{
 | 
			
		||||
				Message: "adaptor not found",
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return adaptor.DoResponse(c, resp, meta)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetModelList() (models []string) {
 | 
			
		||||
	models = modelList
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetChannelName() string {
 | 
			
		||||
	return channelName
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	suffix := ""
 | 
			
		||||
	if strings.HasPrefix(meta.ActualModelName, "gemini") {
 | 
			
		||||
		if meta.IsStream {
 | 
			
		||||
			suffix = "streamGenerateContent?alt=sse"
 | 
			
		||||
		} else {
 | 
			
		||||
			suffix = "generateContent"
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if meta.IsStream {
 | 
			
		||||
			suffix = "streamRawPredict?alt=sse"
 | 
			
		||||
		} else {
 | 
			
		||||
			suffix = "rawPredict"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if meta.BaseURL != "" {
 | 
			
		||||
		return fmt.Sprintf(
 | 
			
		||||
			"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
 | 
			
		||||
			meta.BaseURL,
 | 
			
		||||
			meta.Config.VertexAIProjectID,
 | 
			
		||||
			meta.Config.Region,
 | 
			
		||||
			meta.ActualModelName,
 | 
			
		||||
			suffix,
 | 
			
		||||
		), nil
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf(
 | 
			
		||||
		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
 | 
			
		||||
		meta.Config.Region,
 | 
			
		||||
		meta.Config.VertexAIProjectID,
 | 
			
		||||
		meta.Config.Region,
 | 
			
		||||
		meta.ActualModelName,
 | 
			
		||||
		suffix,
 | 
			
		||||
	), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
 | 
			
		||||
	adaptor.SetupCommonRequestHeader(c, req, meta)
 | 
			
		||||
	token, err := getToken(c, meta.ChannelId, meta.Config.VertexAIADC)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+token)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channelhelper.DoRequestHelper(a, c, meta, requestBody)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										60
									
								
								relay/adaptor/vertexai/claude/adapter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								relay/adaptor/vertexai/claude/adapter.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,60 @@
 | 
			
		||||
package vertexai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ModelList enumerates the Claude models available through Vertex AI;
// Vertex uses "@" date suffixes instead of Anthropic's "-" versions.
var ModelList = []string{
	"claude-3-haiku@20240307",
	"claude-3-sonnet@20240229",
	"claude-3-opus@20240229",
	"claude-3-5-sonnet@20240620",
	"claude-3-5-sonnet-v2@20241022",
	"claude-3-5-haiku@20241022",
}

// anthropicVersion is the fixed API version Vertex AI requires in the body.
const anthropicVersion = "vertex-2023-10-16"

// Adaptor handles Claude models served through Vertex AI.
type Adaptor struct {
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	claudeReq := anthropic.ConvertRequest(*request)
 | 
			
		||||
	req := Request{
 | 
			
		||||
		AnthropicVersion: anthropicVersion,
 | 
			
		||||
		// Model:            claudeReq.Model,
 | 
			
		||||
		Messages:    claudeReq.Messages,
 | 
			
		||||
		System:      claudeReq.System,
 | 
			
		||||
		MaxTokens:   claudeReq.MaxTokens,
 | 
			
		||||
		Temperature: claudeReq.Temperature,
 | 
			
		||||
		TopP:        claudeReq.TopP,
 | 
			
		||||
		TopK:        claudeReq.TopK,
 | 
			
		||||
		Stream:      claudeReq.Stream,
 | 
			
		||||
		Tools:       claudeReq.Tools,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.Set(ctxkey.RequestModel, request.Model)
 | 
			
		||||
	c.Set(ctxkey.ConvertedRequest, req)
 | 
			
		||||
	return req, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		err, usage = anthropic.StreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = anthropic.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								relay/adaptor/vertexai/claude/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								relay/adaptor/vertexai/claude/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package vertexai
 | 
			
		||||
 | 
			
		||||
import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
 | 
			
		||||
 | 
			
		||||
type Request struct {
 | 
			
		||||
	// AnthropicVersion must be "vertex-2023-10-16"
 | 
			
		||||
	AnthropicVersion string `json:"anthropic_version"`
 | 
			
		||||
	// Model            string              `json:"model"`
 | 
			
		||||
	Messages      []anthropic.Message `json:"messages"`
 | 
			
		||||
	System        string              `json:"system,omitempty"`
 | 
			
		||||
	MaxTokens     int                 `json:"max_tokens,omitempty"`
 | 
			
		||||
	StopSequences []string            `json:"stop_sequences,omitempty"`
 | 
			
		||||
	Stream        bool                `json:"stream,omitempty"`
 | 
			
		||||
	Temperature   *float64            `json:"temperature,omitempty"`
 | 
			
		||||
	TopP          *float64            `json:"top_p,omitempty"`
 | 
			
		||||
	TopK          int                 `json:"top_k,omitempty"`
 | 
			
		||||
	Tools         []anthropic.Tool    `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice    any                 `json:"tool_choice,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										53
									
								
								relay/adaptor/vertexai/gemini/adapter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								relay/adaptor/vertexai/gemini/adapter.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
package vertexai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/gemini"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ModelList enumerates the Gemini models available through Vertex AI.
var ModelList = []string{
	"gemini-pro", "gemini-pro-vision",
	"gemini-1.5-pro-001", "gemini-1.5-flash-001",
	"gemini-1.5-pro-002", "gemini-1.5-flash-002",
	"gemini-2.0-flash-exp",
	"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
}
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	geminiRequest := gemini.ConvertRequest(*request)
 | 
			
		||||
	c.Set(ctxkey.RequestModel, request.Model)
 | 
			
		||||
	c.Set(ctxkey.ConvertedRequest, geminiRequest)
 | 
			
		||||
	return geminiRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = gemini.StreamHandler(c, resp)
 | 
			
		||||
		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		switch meta.Mode {
 | 
			
		||||
		case relaymode.Embeddings:
 | 
			
		||||
			err, usage = gemini.EmbeddingHandler(c, resp)
 | 
			
		||||
		default:
 | 
			
		||||
			err, usage = gemini.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										50
									
								
								relay/adaptor/vertexai/registry.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								relay/adaptor/vertexai/registry.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
			
		||||
package vertexai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
 | 
			
		||||
	gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// VertexAIModelType discriminates which model family a model name belongs to.
type VertexAIModelType int

const (
	VerterAIClaude VertexAIModelType = iota + 1
	VerterAIGemini
)

// modelMapping routes a model name to its family; modelList aggregates every
// supported model for GetModelList. Both are populated in init.
var modelMapping = map[string]VertexAIModelType{}
var modelList = []string{}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	modelList = append(modelList, claude.ModelList...)
 | 
			
		||||
	for _, model := range claude.ModelList {
 | 
			
		||||
		modelMapping[model] = VerterAIClaude
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelList = append(modelList, gemini.ModelList...)
 | 
			
		||||
	for _, model := range gemini.ModelList {
 | 
			
		||||
		modelMapping[model] = VerterAIGemini
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type innerAIAdapter interface {
 | 
			
		||||
	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAdaptor(model string) innerAIAdapter {
 | 
			
		||||
	adaptorType := modelMapping[model]
 | 
			
		||||
	switch adaptorType {
 | 
			
		||||
	case VerterAIClaude:
 | 
			
		||||
		return &claude.Adaptor{}
 | 
			
		||||
	case VerterAIGemini:
 | 
			
		||||
		return &gemini.Adaptor{}
 | 
			
		||||
	default:
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										62
									
								
								relay/adaptor/vertexai/token.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								relay/adaptor/vertexai/token.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,62 @@
 | 
			
		||||
package vertexai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	credentials "cloud.google.com/go/iam/credentials/apiv1"
 | 
			
		||||
	"cloud.google.com/go/iam/credentials/apiv1/credentialspb"
 | 
			
		||||
	"github.com/patrickmn/go-cache"
 | 
			
		||||
	"google.golang.org/api/option"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ApplicationDefaultCredentials mirrors the JSON layout of a Google Cloud
// service-account key file (the ADC JSON configured on the channel).
type ApplicationDefaultCredentials struct {
	Type                    string `json:"type"`
	ProjectID               string `json:"project_id"`
	PrivateKeyID            string `json:"private_key_id"`
	PrivateKey              string `json:"private_key"`
	ClientEmail             string `json:"client_email"`
	ClientID                string `json:"client_id"`
	AuthURI                 string `json:"auth_uri"`
	TokenURI                string `json:"token_uri"`
	AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
	ClientX509CertURL       string `json:"client_x509_cert_url"`
	UniverseDomain          string `json:"universe_domain"`
}
 | 
			
		||||
 | 
			
		||||
var Cache = cache.New(50*time.Minute, 55*time.Minute)
 | 
			
		||||
 | 
			
		||||
const defaultScope = "https://www.googleapis.com/auth/cloud-platform"
 | 
			
		||||
 | 
			
		||||
func getToken(ctx context.Context, channelId int, adcJson string) (string, error) {
 | 
			
		||||
	cacheKey := fmt.Sprintf("vertexai-token-%d", channelId)
 | 
			
		||||
	if token, found := Cache.Get(cacheKey); found {
 | 
			
		||||
		return token.(string), nil
 | 
			
		||||
	}
 | 
			
		||||
	adc := &ApplicationDefaultCredentials{}
 | 
			
		||||
	if err := json.Unmarshal([]byte(adcJson), adc); err != nil {
 | 
			
		||||
		return "", fmt.Errorf("Failed to decode credentials file: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(adcJson)))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("Failed to create client: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer c.Close()
 | 
			
		||||
 | 
			
		||||
	req := &credentialspb.GenerateAccessTokenRequest{
 | 
			
		||||
		// See https://pkg.go.dev/cloud.google.com/go/iam/credentials/apiv1/credentialspb#GenerateAccessTokenRequest.
 | 
			
		||||
		Name:  fmt.Sprintf("projects/-/serviceAccounts/%s", adc.ClientEmail),
 | 
			
		||||
		Scope: []string{defaultScope},
 | 
			
		||||
	}
 | 
			
		||||
	resp, err := c.GenerateAccessToken(ctx, req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("Failed to generate access token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	_ = resp
 | 
			
		||||
 | 
			
		||||
	Cache.Set(cacheKey, resp.AccessToken, cache.DefaultExpiration)
 | 
			
		||||
	return resp.AccessToken, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										5
									
								
								relay/adaptor/xai/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								relay/adaptor/xai/constants.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
			
		||||
package xai
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"grok-beta",
 | 
			
		||||
}
 | 
			
		||||
@@ -5,6 +5,8 @@ var ModelList = []string{
 | 
			
		||||
	"SparkDesk-v1.1",
 | 
			
		||||
	"SparkDesk-v2.1",
 | 
			
		||||
	"SparkDesk-v3.1",
 | 
			
		||||
	"SparkDesk-v3.1-128K",
 | 
			
		||||
	"SparkDesk-v3.5",
 | 
			
		||||
	"SparkDesk-v3.5-32K",
 | 
			
		||||
	"SparkDesk-v4.0",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseAPIVersionByModelName(modelName string) string {
 | 
			
		||||
	parts := strings.Split(modelName, "-")
 | 
			
		||||
	if len(parts) == 2 {
 | 
			
		||||
		return parts[1]
 | 
			
		||||
	index := strings.IndexAny(modelName, "-")
 | 
			
		||||
	if index != -1 {
 | 
			
		||||
		return modelName[index+1:]
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
@@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string {
 | 
			
		||||
func apiVersion2domain(apiVersion string) string {
 | 
			
		||||
	switch apiVersion {
 | 
			
		||||
	case "v1.1":
 | 
			
		||||
		return "general"
 | 
			
		||||
		return "lite"
 | 
			
		||||
	case "v2.1":
 | 
			
		||||
		return "generalv2"
 | 
			
		||||
	case "v3.1":
 | 
			
		||||
		return "generalv3"
 | 
			
		||||
	case "v3.1-128K":
 | 
			
		||||
		return "pro-128k"
 | 
			
		||||
	case "v3.5":
 | 
			
		||||
		return "generalv3.5"
 | 
			
		||||
	case "v3.5-32K":
 | 
			
		||||
		return "max-32k"
 | 
			
		||||
	case "v4.0":
 | 
			
		||||
		return "4.0Ultra"
 | 
			
		||||
	}
 | 
			
		||||
@@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
 | 
			
		||||
	var authUrl string
 | 
			
		||||
	domain := apiVersion2domain(apiVersion)
 | 
			
		||||
	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 | 
			
		||||
	switch apiVersion {
 | 
			
		||||
	case "v3.1-128K":
 | 
			
		||||
		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret)
 | 
			
		||||
		break
 | 
			
		||||
	case "v3.5-32K":
 | 
			
		||||
		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
		authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 | 
			
		||||
	}
 | 
			
		||||
	return domain, authUrl
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -19,11 +19,11 @@ type ChatRequest struct {
 | 
			
		||||
	} `json:"header"`
 | 
			
		||||
	Parameter struct {
 | 
			
		||||
		Chat struct {
 | 
			
		||||
			Domain      string  `json:"domain,omitempty"`
 | 
			
		||||
			Temperature float64 `json:"temperature,omitempty"`
 | 
			
		||||
			TopK        int     `json:"top_k,omitempty"`
 | 
			
		||||
			MaxTokens   int     `json:"max_tokens,omitempty"`
 | 
			
		||||
			Auditing    bool    `json:"auditing,omitempty"`
 | 
			
		||||
			Domain      string   `json:"domain,omitempty"`
 | 
			
		||||
			Temperature *float64 `json:"temperature,omitempty"`
 | 
			
		||||
			TopK        int      `json:"top_k,omitempty"`
 | 
			
		||||
			MaxTokens   int      `json:"max_tokens,omitempty"`
 | 
			
		||||
			Auditing    bool     `json:"auditing,omitempty"`
 | 
			
		||||
		} `json:"chat"`
 | 
			
		||||
	} `json:"parameter"`
 | 
			
		||||
	Payload struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -4,13 +4,13 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
	"io"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
@@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
		baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
 | 
			
		||||
		return baiduEmbeddingRequest, err
 | 
			
		||||
	default:
 | 
			
		||||
		// TopP (0.0, 1.0)
 | 
			
		||||
		request.TopP = math.Min(0.99, request.TopP)
 | 
			
		||||
		request.TopP = math.Max(0.01, request.TopP)
 | 
			
		||||
		// TopP [0.0, 1.0]
 | 
			
		||||
		request.TopP = helper.Float64PtrMax(request.TopP, 1)
 | 
			
		||||
		request.TopP = helper.Float64PtrMin(request.TopP, 0)
 | 
			
		||||
 | 
			
		||||
		// Temperature (0.0, 1.0)
 | 
			
		||||
		request.Temperature = math.Min(0.99, request.Temperature)
 | 
			
		||||
		request.Temperature = math.Max(0.01, request.Temperature)
 | 
			
		||||
		// Temperature [0.0, 1.0]
 | 
			
		||||
		request.Temperature = helper.Float64PtrMax(request.Temperature, 1)
 | 
			
		||||
		request.Temperature = helper.Float64PtrMin(request.Temperature, 0)
 | 
			
		||||
		a.SetVersionByModeName(request.Model)
 | 
			
		||||
		if a.APIVersion == "v4" {
 | 
			
		||||
			return request, nil
 | 
			
		||||
 
 | 
			
		||||
@@ -12,8 +12,8 @@ type Message struct {
 | 
			
		||||
 | 
			
		||||
type Request struct {
 | 
			
		||||
	Prompt      []Message `json:"prompt"`
 | 
			
		||||
	Temperature float64   `json:"temperature,omitempty"`
 | 
			
		||||
	TopP        float64   `json:"top_p,omitempty"`
 | 
			
		||||
	Temperature *float64  `json:"temperature,omitempty"`
 | 
			
		||||
	TopP        *float64  `json:"top_p,omitempty"`
 | 
			
		||||
	RequestId   string    `json:"request_id,omitempty"`
 | 
			
		||||
	Incremental bool      `json:"incremental,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,9 @@ const (
 | 
			
		||||
	Cohere
 | 
			
		||||
	Cloudflare
 | 
			
		||||
	DeepL
 | 
			
		||||
	VertexAI
 | 
			
		||||
	Proxy
 | 
			
		||||
	Replicate
 | 
			
		||||
 | 
			
		||||
	Dummy // this one is only for count, do not add any channel after this
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package billing
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
)
 | 
			
		||||
@@ -31,8 +32,17 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQ
 | 
			
		||||
	}
 | 
			
		||||
	// totalQuota is total quota consumed
 | 
			
		||||
	if totalQuota != 0 {
 | 
			
		||||
		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
		model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
 | 
			
		||||
		logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio)
 | 
			
		||||
		model.RecordConsumeLog(ctx, &model.Log{
 | 
			
		||||
			UserId:           userId,
 | 
			
		||||
			ChannelId:        channelId,
 | 
			
		||||
			PromptTokens:     int(totalQuota),
 | 
			
		||||
			CompletionTokens: 0,
 | 
			
		||||
			ModelName:        modelName,
 | 
			
		||||
			TokenName:        tokenName,
 | 
			
		||||
			Quota:            int(totalQuota),
 | 
			
		||||
			Content:          logContent,
 | 
			
		||||
		})
 | 
			
		||||
		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
 | 
			
		||||
		model.UpdateChannelUsedQuota(channelId, totalQuota)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{
 | 
			
		||||
		"720x1280":  1,
 | 
			
		||||
		"1280x720":  1,
 | 
			
		||||
	},
 | 
			
		||||
	"step-1x-medium": {
 | 
			
		||||
		"256x256":   1,
 | 
			
		||||
		"512x512":   1,
 | 
			
		||||
		"768x768":   1,
 | 
			
		||||
		"1024x1024": 1,
 | 
			
		||||
		"1280x800":  1,
 | 
			
		||||
		"800x1280":  1,
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ImageGenerationAmounts = map[string][2]int{
 | 
			
		||||
@@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{
 | 
			
		||||
	"ali-stable-diffusion-v1.5": {1, 4}, // Ali
 | 
			
		||||
	"wanx-v1":                   {1, 4}, // Ali
 | 
			
		||||
	"cogview-3":                 {1, 1},
 | 
			
		||||
	"step-1x-medium":            {1, 1},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ImagePromptLengthLimitations = map[string]int{
 | 
			
		||||
@@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{
 | 
			
		||||
	"ali-stable-diffusion-v1.5": 4000,
 | 
			
		||||
	"wanx-v1":                   4000,
 | 
			
		||||
	"cogview-3":                 833,
 | 
			
		||||
	"step-1x-medium":            4000,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ImageOriginModelName = map[string]string{
 | 
			
		||||
 
 | 
			
		||||
@@ -9,9 +9,10 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	USD2RMB = 7
 | 
			
		||||
	USD     = 500 // $0.002 = 1 -> $1 = 500
 | 
			
		||||
	RMB     = USD / USD2RMB
 | 
			
		||||
	USD2RMB   = 7
 | 
			
		||||
	USD       = 500 // $0.002 = 1 -> $1 = 500
 | 
			
		||||
	MILLI_USD = 1.0 / 1000 * USD
 | 
			
		||||
	RMB       = USD / USD2RMB
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ModelRatio
 | 
			
		||||
@@ -28,15 +29,20 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"gpt-4-32k":               30,
 | 
			
		||||
	"gpt-4-32k-0314":          30,
 | 
			
		||||
	"gpt-4-32k-0613":          30,
 | 
			
		||||
	"gpt-4-1106-preview":      5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-0125-preview":      5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-turbo-preview":     5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-turbo":             5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-turbo-2024-04-09":  5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4o":                  2.5,  // $0.005 / 1K tokens
 | 
			
		||||
	"gpt-4o-2024-05-13":       2.5,  // $0.005 / 1K tokens
 | 
			
		||||
	"gpt-4-vision-preview":    5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo":           0.25, // $0.0005 / 1K tokens
 | 
			
		||||
	"gpt-4-1106-preview":      5,     // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-0125-preview":      5,     // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-turbo-preview":     5,     // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-turbo":             5,     // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-turbo-2024-04-09":  5,     // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4o":                  2.5,   // $0.005 / 1K tokens
 | 
			
		||||
	"chatgpt-4o-latest":       2.5,   // $0.005 / 1K tokens
 | 
			
		||||
	"gpt-4o-2024-05-13":       2.5,   // $0.005 / 1K tokens
 | 
			
		||||
	"gpt-4o-2024-08-06":       1.25,  // $0.0025 / 1K tokens
 | 
			
		||||
	"gpt-4o-2024-11-20":       1.25,  // $0.0025 / 1K tokens
 | 
			
		||||
	"gpt-4o-mini":             0.075, // $0.00015 / 1K tokens
 | 
			
		||||
	"gpt-4o-mini-2024-07-18":  0.075, // $0.00015 / 1K tokens
 | 
			
		||||
	"gpt-4-vision-preview":    5,     // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo":           0.25,  // $0.0005 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0301":      0.75,
 | 
			
		||||
	"gpt-3.5-turbo-0613":      0.75,
 | 
			
		||||
	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
 | 
			
		||||
@@ -44,8 +50,14 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"gpt-3.5-turbo-instruct":  0.75, // $0.0015 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-1106":      0.5,  // $0.001 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0125":      0.25, // $0.0005 / 1K tokens
 | 
			
		||||
	"davinci-002":             1,    // $0.002 / 1K tokens
 | 
			
		||||
	"babbage-002":             0.2,  // $0.0004 / 1K tokens
 | 
			
		||||
	"o1":                      7.5,  // $15.00 / 1M input tokens
 | 
			
		||||
	"o1-2024-12-17":           7.5,
 | 
			
		||||
	"o1-preview":              7.5, // $15.00 / 1M input tokens
 | 
			
		||||
	"o1-preview-2024-09-12":   7.5,
 | 
			
		||||
	"o1-mini":                 1.5, // $3.00 / 1M input tokens
 | 
			
		||||
	"o1-mini-2024-09-12":      1.5,
 | 
			
		||||
	"davinci-002":             1,   // $0.002 / 1K tokens
 | 
			
		||||
	"babbage-002":             0.2, // $0.0004 / 1K tokens
 | 
			
		||||
	"text-ada-001":            0.2,
 | 
			
		||||
	"text-babbage-001":        0.25,
 | 
			
		||||
	"text-curie-001":          1,
 | 
			
		||||
@@ -75,8 +87,10 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"claude-2.0":                 8.0 / 1000 * USD,
 | 
			
		||||
	"claude-2.1":                 8.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-haiku-20240307":    0.25 / 1000 * USD,
 | 
			
		||||
	"claude-3-5-haiku-20241022":  1.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-sonnet-20240229":   3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-opus-20240229":     15.0 / 1000 * USD,
 | 
			
		||||
	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
 | 
			
		||||
	"ERNIE-4.0-8K":       0.120 * RMB,
 | 
			
		||||
@@ -96,12 +110,16 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"bge-large-en":       0.002 * RMB,
 | 
			
		||||
	"tao-8k":             0.002 * RMB,
 | 
			
		||||
	// https://ai.google.dev/pricing
 | 
			
		||||
	"PaLM-2":                    1,
 | 
			
		||||
	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 | 
			
		||||
	"gemini-pro-vision":         1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 | 
			
		||||
	"gemini-1.0-pro-vision-001": 1,
 | 
			
		||||
	"gemini-1.0-pro-001":        1,
 | 
			
		||||
	"gemini-1.5-pro":            1,
 | 
			
		||||
	"gemini-pro":                          1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 | 
			
		||||
	"gemini-1.0-pro":                      1,
 | 
			
		||||
	"gemini-1.5-pro":                      1,
 | 
			
		||||
	"gemini-1.5-pro-001":                  1,
 | 
			
		||||
	"gemini-1.5-flash":                    1,
 | 
			
		||||
	"gemini-1.5-flash-001":                1,
 | 
			
		||||
	"gemini-2.0-flash-exp":                1,
 | 
			
		||||
	"gemini-2.0-flash-thinking-exp":       1,
 | 
			
		||||
	"gemini-2.0-flash-thinking-exp-01-21": 1,
 | 
			
		||||
	"aqa":                                 1,
 | 
			
		||||
	// https://open.bigmodel.cn/pricing
 | 
			
		||||
	"glm-4":         0.1 * RMB,
 | 
			
		||||
	"glm-4v":        0.1 * RMB,
 | 
			
		||||
@@ -113,27 +131,94 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"cogview-3":     0.25 * RMB,
 | 
			
		||||
	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
 | 
			
		||||
	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens
 | 
			
		||||
	"qwen-max":                  1.4286, // ¥0.02 / 1k tokens
 | 
			
		||||
	"qwen-max-longcontext":      1.4286, // ¥0.02 / 1k tokens
 | 
			
		||||
	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
 | 
			
		||||
	"ali-stable-diffusion-xl":   8,
 | 
			
		||||
	"ali-stable-diffusion-v1.5": 8,
 | 
			
		||||
	"wanx-v1":                   8,
 | 
			
		||||
	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v1.1":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v2.1":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.1":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.5":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v4.0":            1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
 | 
			
		||||
	"ChatStd":                   0.01 * RMB,
 | 
			
		||||
	"ChatPro":                   0.1 * RMB,
 | 
			
		||||
	"qwen-turbo":                  1.4286, // ¥0.02 / 1k tokens
 | 
			
		||||
	"qwen-turbo-latest":           1.4286,
 | 
			
		||||
	"qwen-plus":                   1.4286,
 | 
			
		||||
	"qwen-plus-latest":            1.4286,
 | 
			
		||||
	"qwen-max":                    1.4286,
 | 
			
		||||
	"qwen-max-latest":             1.4286,
 | 
			
		||||
	"qwen-max-longcontext":        1.4286,
 | 
			
		||||
	"qwen-vl-max":                 1.4286,
 | 
			
		||||
	"qwen-vl-max-latest":          1.4286,
 | 
			
		||||
	"qwen-vl-plus":                1.4286,
 | 
			
		||||
	"qwen-vl-plus-latest":         1.4286,
 | 
			
		||||
	"qwen-vl-ocr":                 1.4286,
 | 
			
		||||
	"qwen-vl-ocr-latest":          1.4286,
 | 
			
		||||
	"qwen-audio-turbo":            1.4286,
 | 
			
		||||
	"qwen-math-plus":              1.4286,
 | 
			
		||||
	"qwen-math-plus-latest":       1.4286,
 | 
			
		||||
	"qwen-math-turbo":             1.4286,
 | 
			
		||||
	"qwen-math-turbo-latest":      1.4286,
 | 
			
		||||
	"qwen-coder-plus":             1.4286,
 | 
			
		||||
	"qwen-coder-plus-latest":      1.4286,
 | 
			
		||||
	"qwen-coder-turbo":            1.4286,
 | 
			
		||||
	"qwen-coder-turbo-latest":     1.4286,
 | 
			
		||||
	"qwq-32b-preview":             1.4286,
 | 
			
		||||
	"qwen2.5-72b-instruct":        1.4286,
 | 
			
		||||
	"qwen2.5-32b-instruct":        1.4286,
 | 
			
		||||
	"qwen2.5-14b-instruct":        1.4286,
 | 
			
		||||
	"qwen2.5-7b-instruct":         1.4286,
 | 
			
		||||
	"qwen2.5-3b-instruct":         1.4286,
 | 
			
		||||
	"qwen2.5-1.5b-instruct":       1.4286,
 | 
			
		||||
	"qwen2.5-0.5b-instruct":       1.4286,
 | 
			
		||||
	"qwen2-72b-instruct":          1.4286,
 | 
			
		||||
	"qwen2-57b-a14b-instruct":     1.4286,
 | 
			
		||||
	"qwen2-7b-instruct":           1.4286,
 | 
			
		||||
	"qwen2-1.5b-instruct":         1.4286,
 | 
			
		||||
	"qwen2-0.5b-instruct":         1.4286,
 | 
			
		||||
	"qwen1.5-110b-chat":           1.4286,
 | 
			
		||||
	"qwen1.5-72b-chat":            1.4286,
 | 
			
		||||
	"qwen1.5-32b-chat":            1.4286,
 | 
			
		||||
	"qwen1.5-14b-chat":            1.4286,
 | 
			
		||||
	"qwen1.5-7b-chat":             1.4286,
 | 
			
		||||
	"qwen1.5-1.8b-chat":           1.4286,
 | 
			
		||||
	"qwen1.5-0.5b-chat":           1.4286,
 | 
			
		||||
	"qwen-72b-chat":               1.4286,
 | 
			
		||||
	"qwen-14b-chat":               1.4286,
 | 
			
		||||
	"qwen-7b-chat":                1.4286,
 | 
			
		||||
	"qwen-1.8b-chat":              1.4286,
 | 
			
		||||
	"qwen-1.8b-longcontext-chat":  1.4286,
 | 
			
		||||
	"qwen2-vl-7b-instruct":        1.4286,
 | 
			
		||||
	"qwen2-vl-2b-instruct":        1.4286,
 | 
			
		||||
	"qwen-vl-v1":                  1.4286,
 | 
			
		||||
	"qwen-vl-chat-v1":             1.4286,
 | 
			
		||||
	"qwen2-audio-instruct":        1.4286,
 | 
			
		||||
	"qwen-audio-chat":             1.4286,
 | 
			
		||||
	"qwen2.5-math-72b-instruct":   1.4286,
 | 
			
		||||
	"qwen2.5-math-7b-instruct":    1.4286,
 | 
			
		||||
	"qwen2.5-math-1.5b-instruct":  1.4286,
 | 
			
		||||
	"qwen2-math-72b-instruct":     1.4286,
 | 
			
		||||
	"qwen2-math-7b-instruct":      1.4286,
 | 
			
		||||
	"qwen2-math-1.5b-instruct":    1.4286,
 | 
			
		||||
	"qwen2.5-coder-32b-instruct":  1.4286,
 | 
			
		||||
	"qwen2.5-coder-14b-instruct":  1.4286,
 | 
			
		||||
	"qwen2.5-coder-7b-instruct":   1.4286,
 | 
			
		||||
	"qwen2.5-coder-3b-instruct":   1.4286,
 | 
			
		||||
	"qwen2.5-coder-1.5b-instruct": 1.4286,
 | 
			
		||||
	"qwen2.5-coder-0.5b-instruct": 1.4286,
 | 
			
		||||
	"text-embedding-v1":           0.05, // ¥0.0007 / 1k tokens
 | 
			
		||||
	"text-embedding-v3":           0.05,
 | 
			
		||||
	"text-embedding-v2":           0.05,
 | 
			
		||||
	"text-embedding-async-v2":     0.05,
 | 
			
		||||
	"text-embedding-async-v1":     0.05,
 | 
			
		||||
	"ali-stable-diffusion-xl":     8.00,
 | 
			
		||||
	"ali-stable-diffusion-v1.5":   8.00,
 | 
			
		||||
	"wanx-v1":                     8.00,
 | 
			
		||||
	"SparkDesk":                   1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v1.1":              1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v2.1":              1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.1":              1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.1-128K":         1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.5":              1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v3.5-32K":          1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"SparkDesk-v4.0":              1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9":                0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"embedding-bert-512-v1":       0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"embedding_s1_v1":             0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"semantic_similarity_s1_v1":   0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"hunyuan":                     7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
 | 
			
		||||
	"ChatStd":                     0.01 * RMB,
 | 
			
		||||
	"ChatPro":                     0.1 * RMB,
 | 
			
		||||
	// https://platform.moonshot.cn/pricing
 | 
			
		||||
	"moonshot-v1-8k":   0.012 * RMB,
 | 
			
		||||
	"moonshot-v1-32k":  0.024 * RMB,
 | 
			
		||||
@@ -156,20 +241,35 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"mistral-large-latest":  8.0 / 1000 * USD,
 | 
			
		||||
	"mistral-embed":         0.1 / 1000 * USD,
 | 
			
		||||
	// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
 | 
			
		||||
	"llama3-70b-8192":    0.59 / 1000 * USD,
 | 
			
		||||
	"mixtral-8x7b-32768": 0.27 / 1000 * USD,
 | 
			
		||||
	"llama3-8b-8192":     0.05 / 1000 * USD,
 | 
			
		||||
	"gemma-7b-it":        0.1 / 1000 * USD,
 | 
			
		||||
	"llama2-70b-4096":    0.64 / 1000 * USD,
 | 
			
		||||
	"llama2-7b-2048":     0.1 / 1000 * USD,
 | 
			
		||||
	"gemma-7b-it":                           0.07 / 1000000 * USD,
 | 
			
		||||
	"gemma2-9b-it":                          0.20 / 1000000 * USD,
 | 
			
		||||
	"llama-3.1-70b-versatile":               0.59 / 1000000 * USD,
 | 
			
		||||
	"llama-3.1-8b-instant":                  0.05 / 1000000 * USD,
 | 
			
		||||
	"llama-3.2-11b-text-preview":            0.05 / 1000000 * USD,
 | 
			
		||||
	"llama-3.2-11b-vision-preview":          0.05 / 1000000 * USD,
 | 
			
		||||
	"llama-3.2-1b-preview":                  0.05 / 1000000 * USD,
 | 
			
		||||
	"llama-3.2-3b-preview":                  0.05 / 1000000 * USD,
 | 
			
		||||
	"llama-3.2-90b-text-preview":            0.59 / 1000000 * USD,
 | 
			
		||||
	"llama-guard-3-8b":                      0.05 / 1000000 * USD,
 | 
			
		||||
	"llama3-70b-8192":                       0.59 / 1000000 * USD,
 | 
			
		||||
	"llama3-8b-8192":                        0.05 / 1000000 * USD,
 | 
			
		||||
	"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
 | 
			
		||||
	"llama3-groq-8b-8192-tool-use-preview":  0.19 / 1000000 * USD,
 | 
			
		||||
	"mixtral-8x7b-32768":                    0.24 / 1000000 * USD,
 | 
			
		||||
 | 
			
		||||
	// https://platform.lingyiwanwu.com/docs#-计费单元
 | 
			
		||||
	"yi-34b-chat-0205": 2.5 / 1000 * RMB,
 | 
			
		||||
	"yi-34b-chat-200k": 12.0 / 1000 * RMB,
 | 
			
		||||
	"yi-vl-plus":       6.0 / 1000 * RMB,
 | 
			
		||||
	// stepfun todo
 | 
			
		||||
	"step-1v-32k": 0.024 * RMB,
 | 
			
		||||
	"step-1-32k":  0.024 * RMB,
 | 
			
		||||
	"step-1-200k": 0.15 * RMB,
 | 
			
		||||
	// https://platform.stepfun.com/docs/pricing/details
 | 
			
		||||
	"step-1-8k":    0.005 / 1000 * RMB,
 | 
			
		||||
	"step-1-32k":   0.015 / 1000 * RMB,
 | 
			
		||||
	"step-1-128k":  0.040 / 1000 * RMB,
 | 
			
		||||
	"step-1-256k":  0.095 / 1000 * RMB,
 | 
			
		||||
	"step-1-flash": 0.001 / 1000 * RMB,
 | 
			
		||||
	"step-2-16k":   0.038 / 1000 * RMB,
 | 
			
		||||
	"step-1v-8k":   0.005 / 1000 * RMB,
 | 
			
		||||
	"step-1v-32k":  0.015 / 1000 * RMB,
 | 
			
		||||
	// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
 | 
			
		||||
	"llama3-8b-8192(33)":  0.0003 / 0.002,  // $0.0003 / 1K tokens
 | 
			
		||||
	"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
 | 
			
		||||
@@ -181,22 +281,75 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"command-r":             0.5 / 1000 * USD,
 | 
			
		||||
	"command-r-plus":        3.0 / 1000 * USD,
 | 
			
		||||
	// https://platform.deepseek.com/api-docs/pricing/
 | 
			
		||||
	"deepseek-chat":  1.0 / 1000 * RMB,
 | 
			
		||||
	"deepseek-coder": 1.0 / 1000 * RMB,
 | 
			
		||||
	"deepseek-chat":     0.14 * MILLI_USD,
 | 
			
		||||
	"deepseek-reasoner": 0.55 * MILLI_USD,
 | 
			
		||||
	// https://www.deepl.com/pro?cta=header-prices
 | 
			
		||||
	"deepl-zh": 25.0 / 1000 * USD,
 | 
			
		||||
	"deepl-en": 25.0 / 1000 * USD,
 | 
			
		||||
	"deepl-ja": 25.0 / 1000 * USD,
 | 
			
		||||
	// https://console.x.ai/
 | 
			
		||||
	"grok-beta": 5.0 / 1000 * USD,
 | 
			
		||||
	// replicate charges based on the number of generated images
 | 
			
		||||
	// https://replicate.com/pricing
 | 
			
		||||
	"black-forest-labs/flux-1.1-pro":                0.04 * USD,
 | 
			
		||||
	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD,
 | 
			
		||||
	"black-forest-labs/flux-canny-dev":              0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-canny-pro":              0.05 * USD,
 | 
			
		||||
	"black-forest-labs/flux-depth-dev":              0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-depth-pro":              0.05 * USD,
 | 
			
		||||
	"black-forest-labs/flux-dev":                    0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-dev-lora":               0.032 * USD,
 | 
			
		||||
	"black-forest-labs/flux-fill-dev":               0.04 * USD,
 | 
			
		||||
	"black-forest-labs/flux-fill-pro":               0.05 * USD,
 | 
			
		||||
	"black-forest-labs/flux-pro":                    0.055 * USD,
 | 
			
		||||
	"black-forest-labs/flux-redux-dev":              0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-redux-schnell":          0.003 * USD,
 | 
			
		||||
	"black-forest-labs/flux-schnell":                0.003 * USD,
 | 
			
		||||
	"black-forest-labs/flux-schnell-lora":           0.02 * USD,
 | 
			
		||||
	"ideogram-ai/ideogram-v2":                       0.08 * USD,
 | 
			
		||||
	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD,
 | 
			
		||||
	"recraft-ai/recraft-v3":                         0.04 * USD,
 | 
			
		||||
	"recraft-ai/recraft-v3-svg":                     0.08 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3":               0.035 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
 | 
			
		||||
	// replicate chat models
 | 
			
		||||
	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
 | 
			
		||||
	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
 | 
			
		||||
	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
 | 
			
		||||
	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
 | 
			
		||||
	"meta/llama-2-13b":                          0.100 * USD,
 | 
			
		||||
	"meta/llama-2-13b-chat":                     0.100 * USD,
 | 
			
		||||
	"meta/llama-2-70b":                          0.650 * USD,
 | 
			
		||||
	"meta/llama-2-70b-chat":                     0.650 * USD,
 | 
			
		||||
	"meta/llama-2-7b":                           0.050 * USD,
 | 
			
		||||
	"meta/llama-2-7b-chat":                      0.050 * USD,
 | 
			
		||||
	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
 | 
			
		||||
	"meta/meta-llama-3-70b":                     0.650 * USD,
 | 
			
		||||
	"meta/meta-llama-3-70b-instruct":            0.650 * USD,
 | 
			
		||||
	"meta/meta-llama-3-8b":                      0.050 * USD,
 | 
			
		||||
	"meta/meta-llama-3-8b-instruct":             0.050 * USD,
 | 
			
		||||
	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
 | 
			
		||||
	"mistralai/mistral-7b-v0.1":                 0.050 * USD,
 | 
			
		||||
	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var CompletionRatio = map[string]float64{
 | 
			
		||||
	// aws llama3
 | 
			
		||||
	"llama3-8b-8192(33)":  0.0006 / 0.0003,
 | 
			
		||||
	"llama3-70b-8192(33)": 0.0035 / 0.00265,
 | 
			
		||||
	// whisper
 | 
			
		||||
	"whisper-1": 0, // only count input tokens
 | 
			
		||||
	// deepseek
 | 
			
		||||
	"deepseek-chat":     0.28 / 0.14,
 | 
			
		||||
	"deepseek-reasoner": 2.19 / 0.55,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var DefaultModelRatio map[string]float64
 | 
			
		||||
var DefaultCompletionRatio map[string]float64
 | 
			
		||||
var (
 | 
			
		||||
	DefaultModelRatio      map[string]float64
 | 
			
		||||
	DefaultCompletionRatio map[string]float64
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultModelRatio = make(map[string]float64)
 | 
			
		||||
@@ -308,13 +461,25 @@ func GetCompletionRatio(name string, channelType int) float64 {
 | 
			
		||||
		return 4.0 / 3.0
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4") {
 | 
			
		||||
		if strings.HasPrefix(name, "gpt-4o") {
 | 
			
		||||
			if name == "gpt-4o-2024-05-13" {
 | 
			
		||||
				return 3
 | 
			
		||||
			}
 | 
			
		||||
			return 4
 | 
			
		||||
		}
 | 
			
		||||
		if strings.HasPrefix(name, "gpt-4-turbo") ||
 | 
			
		||||
			strings.HasPrefix(name, "gpt-4o") ||
 | 
			
		||||
			strings.HasSuffix(name, "preview") {
 | 
			
		||||
			return 3
 | 
			
		||||
		}
 | 
			
		||||
		return 2
 | 
			
		||||
	}
 | 
			
		||||
	// including o1, o1-preview, o1-mini
 | 
			
		||||
	if strings.HasPrefix(name, "o1") {
 | 
			
		||||
		return 4
 | 
			
		||||
	}
 | 
			
		||||
	if name == "chatgpt-4o-latest" {
 | 
			
		||||
		return 3
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "claude-3") {
 | 
			
		||||
		return 5
 | 
			
		||||
	}
 | 
			
		||||
@@ -330,6 +495,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
 | 
			
		||||
	if strings.HasPrefix(name, "deepseek-") {
 | 
			
		||||
		return 2
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch name {
 | 
			
		||||
	case "llama2-70b-4096":
 | 
			
		||||
		return 0.8 / 0.64
 | 
			
		||||
@@ -343,6 +509,37 @@ func GetCompletionRatio(name string, channelType int) float64 {
 | 
			
		||||
		return 3
 | 
			
		||||
	case "command-r-plus":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "grok-beta":
 | 
			
		||||
		return 3
 | 
			
		||||
	// Replicate Models
 | 
			
		||||
	// https://replicate.com/pricing
 | 
			
		||||
	case "ibm-granite/granite-20b-code-instruct-8k":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "ibm-granite/granite-3.0-2b-instruct":
 | 
			
		||||
		return 8.333333333333334
 | 
			
		||||
	case "ibm-granite/granite-3.0-8b-instruct",
 | 
			
		||||
		"ibm-granite/granite-8b-code-instruct-128k":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "meta/llama-2-13b",
 | 
			
		||||
		"meta/llama-2-13b-chat",
 | 
			
		||||
		"meta/llama-2-7b",
 | 
			
		||||
		"meta/llama-2-7b-chat",
 | 
			
		||||
		"meta/meta-llama-3-8b",
 | 
			
		||||
		"meta/meta-llama-3-8b-instruct":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "meta/llama-2-70b",
 | 
			
		||||
		"meta/llama-2-70b-chat",
 | 
			
		||||
		"meta/meta-llama-3-70b",
 | 
			
		||||
		"meta/meta-llama-3-70b-instruct":
 | 
			
		||||
		return 2.750 / 0.650 // ≈4.230769
 | 
			
		||||
	case "meta/meta-llama-3.1-405b-instruct":
 | 
			
		||||
		return 1
 | 
			
		||||
	case "mistralai/mistral-7b-instruct-v0.2",
 | 
			
		||||
		"mistralai/mistral-7b-v0.1":
 | 
			
		||||
		return 5
 | 
			
		||||
	case "mistralai/mixtral-8x7b-instruct-v0.1":
 | 
			
		||||
		return 1.000 / 0.300 // ≈3.333333
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user