mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			73 Commits
		
	
	
		
			v0.5.6-alp
			...
			v0.5.9
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 01f7b0186f | ||
|  | a3f80a3392 | ||
|  | 8f5b83562b | ||
|  | b7570d5c77 | ||
|  | 0e73418cdf | ||
|  | 9889377f0e | ||
|  | b273464e77 | ||
|  | b4e43d97fd | ||
|  | 3347a44023 | ||
|  | 923e24534b | ||
|  | b4d67ca614 | ||
|  | d85e356b6e | ||
|  | 495fc628e4 | ||
|  | 76f9288c34 | ||
|  | 915d13fdd4 | ||
|  | 969f539777 | ||
|  | 54e5f8ecd2 | ||
|  | 34d517cfa2 | ||
|  | ddcaf95f5f | ||
|  | 1d15157f7d | ||
|  | de7b9710a5 | ||
|  | 58bb3ab6f6 | ||
|  | d306cb5229 | ||
|  | 6c5307d0c4 | ||
|  | 7c4505bdfc | ||
|  | 9d43ec57d8 | ||
|  | e5311892d1 | ||
|  | bc7c9105f4 | ||
|  | 3fe76c8af7 | ||
|  | c70c614018 | ||
|  | 0d87de697c | ||
|  | aec343dc38 | ||
|  | 89d458b9cf | ||
|  | 63fafba112 | ||
|  | a398f35968 | ||
|  | 57aa637c77 | ||
|  | 3b483639a4 | ||
|  | 22980b4c44 | ||
|  | 64cdb7eafb | ||
|  | 824444244b | ||
|  | fbe9985f57 | ||
|  | a27a5bcc06 | ||
|  | e28d4b1741 | ||
|  | f073592d39 | ||
|  | fa41ca9805 | ||
|  | e338de45b6 | ||
|  | 114587b46f | ||
|  | b4b4acc288 | ||
|  | d663de3e3a | ||
|  | a85ecace2e | ||
|  | fbdea91ea1 | ||
|  | 8d34b7a77e | ||
|  | cbd62011b8 | ||
|  | 4701897e2e | ||
|  | 0f6c132a80 | ||
|  | 3cac45dc85 | ||
|  | 47c08c72ce | ||
|  | 53b2cace0b | ||
|  | f0fc991b44 | ||
|  | 594f06e7b0 | ||
|  | 197d1d7a9d | ||
|  | f9b748c2ca | ||
|  | fd98463611 | ||
|  | f5a1cd3463 | ||
|  | 8651451e53 | ||
|  | 1c5bb97a42 | ||
|  | de868e4e4e | ||
|  | 1d258cc898 | ||
|  | 37e09d764c | ||
|  | 159b9e3369 | ||
|  | 92001986db | ||
|  | a5647b1ea7 | ||
|  | 215e54fc96 | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -5,4 +5,5 @@ upload | ||||
| *.db | ||||
| build | ||||
| *.db-journal | ||||
| logs | ||||
| logs | ||||
| data | ||||
| @@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | ||||
|  | ||||
| > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. First, fork the code. | ||||
| 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. | ||||
| 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). | ||||
|   | ||||
| @@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for | ||||
|  | ||||
| > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. まず、コードをフォークする。 | ||||
| 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 | ||||
| 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 | ||||
|   | ||||
							
								
								
									
										100
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								README.md
									
									
									
									
									
								
							| @@ -51,14 +51,17 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|   <a href="https://iamazing.cn/page/reward">赞赏支持</a> | ||||
| </p> | ||||
|  | ||||
| > **Note** | ||||
| > [!NOTE] | ||||
| > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 | ||||
| >  | ||||
| > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 | ||||
|  | ||||
| > **Warning** | ||||
| > [!WARNING] | ||||
| > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | ||||
|  | ||||
| > [!WARNING] | ||||
| > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | ||||
|  | ||||
| ## 功能 | ||||
| 1. 支持多种大模型: | ||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
| @@ -69,9 +72,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||
|    + [x] [360 智脑](https://ai.360.cn) | ||||
|    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||
| 2. 支持配置镜像以及众多第三方代理服务: | ||||
|    + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||
|    + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) | ||||
|    + [x] [API2D](https://api2d.com/r/197971) | ||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||
| @@ -88,26 +92,33 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 12. 支持**用户邀请奖励**。 | ||||
| 13. 支持以美元为单位显示额度。 | ||||
| 14. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||
| 15. 支持模型映射,重定向用户的请求模型。 | ||||
| 15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。 | ||||
| 16. 支持失败自动重试。 | ||||
| 17. 支持绘图接口。 | ||||
| 18. 支持丰富的**自定义**设置, | ||||
| 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 | ||||
| 19. 支持丰富的**自定义**设置, | ||||
|     1. 支持自定义系统名称,logo 以及页脚。 | ||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||
| 19. 支持通过系统访问令牌访问管理 API。 | ||||
| 20. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 21. 支持用户管理,支持**多种用户登录注册方式**: | ||||
| 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
|  | ||||
| ## 部署 | ||||
| ### 基于 Docker 进行部署 | ||||
| 部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` | ||||
| ```shell | ||||
| # 使用 SQLite 的部署命令: | ||||
| docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||
| # 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。 | ||||
| # 例如: | ||||
| docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||
| ``` | ||||
|  | ||||
| 其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 | ||||
|  | ||||
| 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | ||||
| 数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | ||||
|  | ||||
| 如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 | ||||
|  | ||||
| @@ -149,6 +160,19 @@ sudo service nginx restart | ||||
|  | ||||
| 初始账号用户名为 `root`,密码为 `123456`。 | ||||
|  | ||||
|  | ||||
| ### 基于 Docker Compose 进行部署 | ||||
|  | ||||
| > 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 | ||||
|  | ||||
| ```shell | ||||
| # 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 | ||||
| docker-compose up -d | ||||
|  | ||||
| # 查看部署状态 | ||||
| docker-compose ps | ||||
| ``` | ||||
|  | ||||
| ### 手动部署 | ||||
| 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: | ||||
|    ```shell | ||||
| @@ -236,7 +260,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
| <summary><strong>部署到 Zeabur</strong></summary> | ||||
| <div> | ||||
|  | ||||
| > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。 | ||||
| > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 | ||||
|  | ||||
| [](https://zeabur.com/templates/7Q0KO3) | ||||
|  | ||||
| 1. 首先 fork 一份代码。 | ||||
| 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | ||||
| @@ -251,6 +277,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
| </div> | ||||
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><strong>部署到 Render</strong></summary> | ||||
| <div> | ||||
|  | ||||
| > Render 提供免费额度,绑卡后可以进一步提升额度 | ||||
|  | ||||
| Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
|  | ||||
| ## 配置 | ||||
| 系统本身开箱即用。 | ||||
|  | ||||
| @@ -269,13 +306,20 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
|  | ||||
| 注意,具体的 API Base 的格式取决于你所使用的客户端。 | ||||
|  | ||||
| 例如对于 OpenAI 的官方库: | ||||
| ```bash | ||||
| OPENAI_API_KEY="sk-xxxxxx" | ||||
| OPENAI_API_BASE="https://<HOST>:<PORT>/v1"  | ||||
| ``` | ||||
|  | ||||
| ```mermaid | ||||
| graph LR | ||||
|     A(用户) | ||||
|     A --->|请求| B(One API) | ||||
|     A --->|使用 One API 分发的 key 进行请求| B(One API) | ||||
|     B -->|中继请求| C(OpenAI) | ||||
|     B -->|中继请求| D(Azure) | ||||
|     B -->|中继请求| E(其他下游渠道) | ||||
|     B -->|中继请求| E(其他 OpenAI API 格式下游渠道) | ||||
|     B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道) | ||||
| ``` | ||||
|  | ||||
| 可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 | ||||
| @@ -303,24 +347,30 @@ graph LR | ||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||
| 5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。 | ||||
| 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||
| 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||
|    + 例子:`SYNC_FREQUENCY=60` | ||||
| 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||
| 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||
|    + 例子:`NODE_TYPE=slave` | ||||
| 7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||
| 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||
| 8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||
| 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||
| 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|    + 例子:`POLLING_INTERVAL=5` | ||||
| 10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
| 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||
|     + 例子:`POLLING_INTERVAL=5` | ||||
| 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||
| 11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
| 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||
| 12. 请求频率限制: | ||||
| 13. 请求频率限制: | ||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||
| 14. 编码器缓存设置: | ||||
|     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||
|     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||
| 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
| @@ -360,6 +410,12 @@ https://openai.justsong.cn | ||||
|    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||
|    + 上游通道 429 了。 | ||||
| 7. 升级之后我的数据会丢失吗? | ||||
|    + 如果使用 MySQL,不会。 | ||||
|    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||
| 8. 升级之前数据库需要做变更吗? | ||||
|    + 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||
|    + 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||
|  | ||||
| ## 相关项目 | ||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||
|   | ||||
| @@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | ||||
| var DisplayInCurrencyEnabled = true | ||||
| var DisplayTokenStatEnabled = true | ||||
|  | ||||
| var UsingSQLite = false | ||||
|  | ||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||
|  | ||||
| var SessionSecret = uuid.New().String() | ||||
| var SQLitePath = "one-api.db" | ||||
|  | ||||
| var OptionMap map[string]string | ||||
| var OptionMapRWMutex sync.RWMutex | ||||
| @@ -56,6 +53,7 @@ var EmailDomainWhitelist = []string{ | ||||
| } | ||||
|  | ||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||
| var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||
|  | ||||
| var LogConsumeEnabled = true | ||||
|  | ||||
| @@ -80,6 +78,7 @@ var QuotaForInviter = 0 | ||||
| var QuotaForInvitee = 0 | ||||
| var ChannelDisableThreshold = 5.0 | ||||
| var AutomaticDisableChannelEnabled = false | ||||
| var AutomaticEnableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| @@ -92,11 +91,13 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | ||||
|  | ||||
| var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY | ||||
| var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||
|  | ||||
| var BatchUpdateEnabled = false | ||||
| var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||
|  | ||||
| var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second | ||||
|  | ||||
| const ( | ||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | ||||
| ) | ||||
| @@ -155,9 +156,10 @@ const ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelStatusUnknown  = 0 | ||||
| 	ChannelStatusEnabled  = 1 // don't use 0, 0 is the default value! | ||||
| 	ChannelStatusDisabled = 2 // also don't use 0 | ||||
| 	ChannelStatusUnknown          = 0 | ||||
| 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! | ||||
| 	ChannelStatusManuallyDisabled = 2 // also don't use 0 | ||||
| 	ChannelStatusAutoDisabled     = 3 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -184,30 +186,32 @@ const ( | ||||
| 	ChannelTypeOpenRouter     = 20 | ||||
| 	ChannelTypeAIProxyLibrary = 21 | ||||
| 	ChannelTypeFastGPT        = 22 | ||||
| 	ChannelTypeTencent        = 23 | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| 	"",                                // 0 | ||||
| 	"https://api.openai.com",          // 1 | ||||
| 	"https://oa.api2d.net",            // 2 | ||||
| 	"",                                // 3 | ||||
| 	"https://api.closeai-proxy.xyz",   // 4 | ||||
| 	"https://api.openai-sb.com",       // 5 | ||||
| 	"https://api.openaimax.com",       // 6 | ||||
| 	"https://api.ohmygpt.com",         // 7 | ||||
| 	"",                                // 8 | ||||
| 	"https://api.caipacity.com",       // 9 | ||||
| 	"https://api.aiproxy.io",          // 10 | ||||
| 	"",                                // 11 | ||||
| 	"https://api.api2gpt.com",         // 12 | ||||
| 	"https://api.aigc2d.com",          // 13 | ||||
| 	"https://api.anthropic.com",       // 14 | ||||
| 	"https://aip.baidubce.com",        // 15 | ||||
| 	"https://open.bigmodel.cn",        // 16 | ||||
| 	"https://dashscope.aliyuncs.com",  // 17 | ||||
| 	"",                                // 18 | ||||
| 	"https://ai.360.cn",               // 19 | ||||
| 	"https://openrouter.ai/api",       // 20 | ||||
| 	"https://api.aiproxy.io",          // 21 | ||||
| 	"https://fastgpt.run/api/openapi", // 22 | ||||
| 	"",                                  // 0 | ||||
| 	"https://api.openai.com",            // 1 | ||||
| 	"https://oa.api2d.net",              // 2 | ||||
| 	"",                                  // 3 | ||||
| 	"https://api.closeai-proxy.xyz",     // 4 | ||||
| 	"https://api.openai-sb.com",         // 5 | ||||
| 	"https://api.openaimax.com",         // 6 | ||||
| 	"https://api.ohmygpt.com",           // 7 | ||||
| 	"",                                  // 8 | ||||
| 	"https://api.caipacity.com",         // 9 | ||||
| 	"https://api.aiproxy.io",            // 10 | ||||
| 	"",                                  // 11 | ||||
| 	"https://api.api2gpt.com",           // 12 | ||||
| 	"https://api.aigc2d.com",            // 13 | ||||
| 	"https://api.anthropic.com",         // 14 | ||||
| 	"https://aip.baidubce.com",          // 15 | ||||
| 	"https://open.bigmodel.cn",          // 16 | ||||
| 	"https://dashscope.aliyuncs.com",    // 17 | ||||
| 	"",                                  // 18 | ||||
| 	"https://ai.360.cn",                 // 19 | ||||
| 	"https://openrouter.ai/api",         // 20 | ||||
| 	"https://api.aiproxy.io",            // 21 | ||||
| 	"https://fastgpt.run/api/openapi",   // 22 | ||||
| 	"https://hunyuan.cloud.tencent.com", //23 | ||||
| } | ||||
|   | ||||
							
								
								
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package common | ||||
|  | ||||
| var UsingSQLite = false | ||||
| var UsingPostgreSQL = false | ||||
|  | ||||
| var SQLitePath = "one-api.db" | ||||
| @@ -1,11 +1,13 @@ | ||||
| package common | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"net/smtp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| @@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error { | ||||
| 		SMTPFrom = SMTPAccount | ||||
| 	} | ||||
| 	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) | ||||
|  | ||||
| 	// Extract domain from SMTPFrom | ||||
| 	parts := strings.Split(SMTPFrom, "@") | ||||
| 	var domain string | ||||
| 	if len(parts) > 1 { | ||||
| 		domain = parts[1] | ||||
| 	} | ||||
| 	// Generate a unique Message-ID | ||||
| 	buf := make([]byte, 16) | ||||
| 	_, err := rand.Read(buf) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	messageId := fmt.Sprintf("<%x@%s>", buf, domain) | ||||
|  | ||||
| 	mail := []byte(fmt.Sprintf("To: %s\r\n"+ | ||||
| 		"From: %s<%s>\r\n"+ | ||||
| 		"Subject: %s\r\n"+ | ||||
| 		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 | ||||
| 		"Date: %s\r\n"+ | ||||
| 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", | ||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, content)) | ||||
| 		receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) | ||||
| 	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) | ||||
| 	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) | ||||
| 	to := strings.Split(receiver, ";") | ||||
| 	var err error | ||||
|  | ||||
| 	if SMTPPort == 465 { | ||||
| 		tlsConfig := &tls.Config{ | ||||
| 			InsecureSkipVerify: true, | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| @@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	err = json.Unmarshal(requestBody, &v) | ||||
| 	contentType := c.Request.Header.Get("Content-Type") | ||||
| 	if strings.HasPrefix(contentType, "application/json") { | ||||
| 		err = json.Unmarshal(requestBody, &v) | ||||
| 	} else { | ||||
| 		// skip for now | ||||
| 		// TODO: someday non json request have variant model, we will need to implementation this | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -3,8 +3,32 @@ package common | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var DalleSizeRatios = map[string]map[string]float64{ | ||||
| 	"dall-e-2": { | ||||
| 		"256x256":   1, | ||||
| 		"512x512":   1.125, | ||||
| 		"1024x1024": 1.25, | ||||
| 	}, | ||||
| 	"dall-e-3": { | ||||
| 		"1024x1024": 1, | ||||
| 		"1024x1792": 2, | ||||
| 		"1792x1024": 2, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| var DalleGenerationImageAmounts = map[string][2]int{ | ||||
| 	"dall-e-2": {1, 10}, | ||||
| 	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently. | ||||
| } | ||||
|  | ||||
| var DalleImagePromptLengthLimitations = map[string]int{ | ||||
| 	"dall-e-2": 1000, | ||||
| 	"dall-e-3": 4000, | ||||
| } | ||||
|  | ||||
| // ModelRatio | ||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||
| @@ -19,11 +43,15 @@ var ModelRatio = map[string]float64{ | ||||
| 	"gpt-4-32k":                 30, | ||||
| 	"gpt-4-32k-0314":            30, | ||||
| 	"gpt-4-32k-0613":            30, | ||||
| 	"gpt-4-1106-preview":        5,    // $0.01 / 1K tokens | ||||
| 	"gpt-4-vision-preview":      5,    // $0.01 / 1K tokens | ||||
| 	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-0301":        0.75, | ||||
| 	"gpt-3.5-turbo-0613":        0.75, | ||||
| 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | ||||
| 	"gpt-3.5-turbo-16k-0613":    1.5, | ||||
| 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens | ||||
| 	"gpt-3.5-turbo-1106":        0.5,  // $0.001 / 1K tokens | ||||
| 	"text-ada-001":              0.2, | ||||
| 	"text-babbage-001":          0.25, | ||||
| 	"text-curie-001":            1, | ||||
| @@ -31,7 +59,11 @@ var ModelRatio = map[string]float64{ | ||||
| 	"text-davinci-003":          10, | ||||
| 	"text-davinci-edit-001":     10, | ||||
| 	"code-davinci-edit-001":     10, | ||||
| 	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||
| 	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens | ||||
| 	"tts-1":                     7.5, // $0.015 / 1K characters | ||||
| 	"tts-1-1106":                7.5, | ||||
| 	"tts-1-hd":                  15, // $0.030 / 1K characters | ||||
| 	"tts-1-hd-1106":             15, | ||||
| 	"davinci":                   10, | ||||
| 	"curie":                     10, | ||||
| 	"babbage":                   10, | ||||
| @@ -40,25 +72,30 @@ var ModelRatio = map[string]float64{ | ||||
| 	"text-search-ada-doc-001":   10, | ||||
| 	"text-moderation-stable":    0.1, | ||||
| 	"text-moderation-latest":    0.1, | ||||
| 	"dall-e":                    8, | ||||
| 	"dall-e-2":                  8,      // $0.016 - $0.020 / image | ||||
| 	"dall-e-3":                  20,     // $0.040 - $0.120 / image | ||||
| 	"claude-instant-1":          0.815,  // $1.63 / 1M tokens | ||||
| 	"claude-2":                  5.51,   // $11.02 / 1M tokens | ||||
| 	"claude-2.0":                5.51,   // $11.02 / 1M tokens | ||||
| 	"claude-2.1":                5.51,   // $11.02 / 1M tokens | ||||
| 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens | ||||
| 	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens | ||||
| 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | ||||
| 	"PaLM-2":                    1, | ||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens | ||||
| 	"qwen-v1":                   0.8572, // ¥0.012 / 1k tokens | ||||
| 	"qwen-plus-v1":              1,      // ¥0.014 / 1k tokens | ||||
| 	"qwen-turbo":                0.8572, // ¥0.012 / 1k tokens | ||||
| 	"qwen-plus":                 10,     // ¥0.14 / 1k tokens | ||||
| 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | ||||
| 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | ||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | ||||
| 	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens | ||||
| 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||
| } | ||||
|  | ||||
| func ModelRatio2JSONString() string { | ||||
| @@ -85,9 +122,24 @@ func GetModelRatio(name string) float64 { | ||||
|  | ||||
| func GetCompletionRatio(name string) float64 { | ||||
| 	if strings.HasPrefix(name, "gpt-3.5") { | ||||
| 		if strings.HasSuffix(name, "1106") { | ||||
| 			return 2 | ||||
| 		} | ||||
| 		if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { | ||||
| 			// TODO: clear this after 2023-12-11 | ||||
| 			now := time.Now() | ||||
| 			// https://platform.openai.com/docs/models/continuous-model-upgrades | ||||
| 			// if after 2023-12-11, use 2 | ||||
| 			if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { | ||||
| 				return 2 | ||||
| 			} | ||||
| 		} | ||||
| 		return 1.333333 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "gpt-4") { | ||||
| 		if strings.HasSuffix(name, "preview") { | ||||
| 			return 3 | ||||
| 		} | ||||
| 		return 2 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "claude-instant-1") { | ||||
|   | ||||
| @@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int { | ||||
| func MessageWithRequestId(message string, id string) string { | ||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||
| } | ||||
|  | ||||
| func String2Int(str string) int { | ||||
| 	num, err := strconv.Atoi(str) | ||||
| 	if err != nil { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|   | ||||
| @@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | ||||
| } | ||||
|  | ||||
| func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) | ||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
|  | ||||
| 	if err != nil { | ||||
| @@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.BaseURL == "" { | ||||
| 		channel.BaseURL = baseURL | ||||
| 	if channel.GetBaseURL() == "" { | ||||
| 		channel.BaseURL = &baseURL | ||||
| 	} | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypeOpenAI: | ||||
| 		if channel.BaseURL != "" { | ||||
| 			baseURL = channel.BaseURL | ||||
| 		if channel.GetBaseURL() != "" { | ||||
| 			baseURL = channel.GetBaseURL() | ||||
| 		} | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	case common.ChannelTypeCustom: | ||||
| 		baseURL = channel.BaseURL | ||||
| 		baseURL = channel.GetBaseURL() | ||||
| 	case common.ChannelTypeCloseAI: | ||||
| 		return updateChannelCloseAIBalance(channel) | ||||
| 	case common.ChannelTypeOpenAISB: | ||||
|   | ||||
| @@ -5,13 +5,15 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | ||||
| @@ -42,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 	} | ||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) | ||||
| 		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) | ||||
| 	} else { | ||||
| 		if channel.BaseURL != "" { | ||||
| 			requestURL = channel.BaseURL | ||||
| 		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { | ||||
| 			requestURL = baseURL | ||||
| 		} | ||||
| 		requestURL += "/v1/chat/completions" | ||||
| 	} | ||||
|  | ||||
| 		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(request) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| @@ -70,11 +72,18 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	var response TextResponse | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&response) | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &response) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil | ||||
| 	} | ||||
| 	if response.Usage.CompletionTokens == 0 { | ||||
| 		if response.Error.Message == "" { | ||||
| 			response.Error.Message = "补全 tokens 非预期返回 0" | ||||
| 		} | ||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||
| 	} | ||||
| 	return nil, nil | ||||
| @@ -136,20 +145,32 @@ func TestChannel(c *gin.Context) { | ||||
| var testAllChannelsLock sync.Mutex | ||||
| var testAllChannelsRunning bool = false | ||||
|  | ||||
| // disable & notify | ||||
| func disableChannel(channelId int, channelName string, reason string) { | ||||
| func notifyRootUser(subject string, content string) { | ||||
| 	if common.RootUserEmail == "" { | ||||
| 		common.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	err := common.SendEmail(subject, common.RootUserEmail, content) | ||||
| 	if err != nil { | ||||
| 		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // disable & notify | ||||
| func disableChannel(channelId int, channelName string, reason string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| // enable & notify | ||||
| func enableChannel(channelId int, channelName string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| func testAllChannels(notify bool) error { | ||||
| 	if common.RootUserEmail == "" { | ||||
| 		common.RootUserEmail = model.GetRootUserEmail() | ||||
| @@ -172,20 +193,21 @@ func testAllChannels(notify bool) error { | ||||
| 	} | ||||
| 	go func() { | ||||
| 		for _, channel := range channels { | ||||
| 			if channel.Status != common.ChannelStatusEnabled { | ||||
| 				continue | ||||
| 			} | ||||
| 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled | ||||
| 			tik := time.Now() | ||||
| 			err, openaiErr := testChannel(channel, *testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if milliseconds > disableThreshold { | ||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if shouldDisableChannel(openaiErr, -1) { | ||||
| 			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { | ||||
| 				enableChannel(channel.Id, channel.Name) | ||||
| 			} | ||||
| 			channel.UpdateResponseTime(milliseconds) | ||||
| 			time.Sleep(common.RequestInterval) | ||||
| 		} | ||||
|   | ||||
| @@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func DeleteDisabledChannel(c *gin.Context) { | ||||
| 	rows, err := model.DeleteDisabledChannel() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    rows, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func UpdateChannel(c *gin.Context) { | ||||
| 	channel := model.Channel{} | ||||
| 	err := c.ShouldBindJSON(&channel) | ||||
|   | ||||
| @@ -55,12 +55,21 @@ func init() { | ||||
| 	// https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| 	openAIModels = []OpenAIModels{ | ||||
| 		{ | ||||
| 			Id:         "dall-e", | ||||
| 			Id:         "dall-e-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "dall-e", | ||||
| 			Root:       "dall-e-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "dall-e-3", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "dall-e-3", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -72,6 +81,42 @@ func init() { | ||||
| 			Root:       "whisper-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-hd", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-hd", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "tts-1-hd-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "tts-1-hd-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo", | ||||
| 			Object:     "model", | ||||
| @@ -117,6 +162,24 @@ func init() { | ||||
| 			Root:       "gpt-3.5-turbo-16k-0613", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo-1106", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-3.5-turbo-1106", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-3.5-turbo-instruct", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-3.5-turbo-instruct", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-4", | ||||
| 			Object:     "model", | ||||
| @@ -171,6 +234,24 @@ func init() { | ||||
| 			Root:       "gpt-4-32k-0613", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-4-1106-preview", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-4-1106-preview", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "gpt-4-vision-preview", | ||||
| 			Object:     "model", | ||||
| 			Created:    1699593571, | ||||
| 			OwnedBy:    "openai", | ||||
| 			Permission: permission, | ||||
| 			Root:       "gpt-4-vision-preview", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "text-embedding-ada-002", | ||||
| 			Object:     "model", | ||||
| @@ -265,7 +346,7 @@ func init() { | ||||
| 			Id:         "claude-instant-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-instant-1", | ||||
| 			Parent:     nil, | ||||
| @@ -274,11 +355,29 @@ func init() { | ||||
| 			Id:         "claude-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2.1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2.1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2.0", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anthropic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2.0", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot", | ||||
| 			Object:     "model", | ||||
| @@ -297,6 +396,15 @@ func init() { | ||||
| 			Root:       "ERNIE-Bot-turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot-4", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "ERNIE-Bot-4", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "Embedding-V1", | ||||
| 			Object:     "model", | ||||
| @@ -315,6 +423,15 @@ func init() { | ||||
| 			Root:       "PaLM-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_turbo", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_pro", | ||||
| 			Object:     "model", | ||||
| @@ -343,21 +460,21 @@ func init() { | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-v1", | ||||
| 			Id:         "qwen-turbo", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-v1", | ||||
| 			Root:       "qwen-turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-plus-v1", | ||||
| 			Id:         "qwen-plus", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-plus-v1", | ||||
| 			Root:       "qwen-plus", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -415,12 +532,12 @@ func init() { | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "360GPT_S2_V9.4", | ||||
| 			Id:         "hunyuan", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "360", | ||||
| 			OwnedBy:    "tencent", | ||||
| 			Permission: permission, | ||||
| 			Root:       "360GPT_S2_V9.4", | ||||
| 			Root:       "hunyuan", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 	} | ||||
|   | ||||
| @@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) { | ||||
| 		if option.Value == "true" && common.GitHubClientId == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!", | ||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
| @@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct { | ||||
| func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { | ||||
| 	query := "" | ||||
| 	if len(request.Messages) != 0 { | ||||
| 		query = request.Messages[len(request.Messages)-1].Content | ||||
| 		query = request.Messages[len(request.Messages)-1].StringContent() | ||||
| 	} | ||||
| 	return &AIProxyLibraryRequest{ | ||||
| 		Model:  request.Model, | ||||
|   | ||||
| @@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				User: message.StringContent(), | ||||
| 				Bot:  "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			if i == len(request.Messages)-1 { | ||||
| 				prompt = message.Content | ||||
| 				prompt = message.StringContent() | ||||
| 				break | ||||
| 			} | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  request.Messages[i+1].Content, | ||||
| 				User: message.StringContent(), | ||||
| 				Bot:  request.Messages[i+1].StringContent(), | ||||
| 			}) | ||||
| 			i++ | ||||
| 		} | ||||
|   | ||||
| @@ -4,13 +4,14 @@ import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| @@ -21,16 +22,44 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
| 	tokenName := c.GetString("token_name") | ||||
|  | ||||
| 	var ttsRequest TextToSpeechRequest | ||||
| 	if relayMode == RelayModeAudioSpeech { | ||||
| 		// Read JSON | ||||
| 		err := common.UnmarshalBodyReusable(c, &ttsRequest) | ||||
| 		// Check if JSON is valid | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "invalid_json", http.StatusBadRequest) | ||||
| 		} | ||||
| 		audioModel = ttsRequest.Model | ||||
| 		// Check if text is too long 4096 | ||||
| 		if len(ttsRequest.Input) > 4096 { | ||||
| 			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	preConsumedTokens := common.PreConsumedQuota | ||||
| 	modelRatio := common.GetModelRatio(audioModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	preConsumedQuota := int(float64(preConsumedTokens) * ratio) | ||||
| 	var quota int | ||||
| 	var preConsumedQuota int | ||||
| 	switch relayMode { | ||||
| 	case RelayModeAudioSpeech: | ||||
| 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) | ||||
| 		quota = preConsumedQuota | ||||
| 	default: | ||||
| 		preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) | ||||
| 	} | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	// Check if user quota is enough | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| @@ -62,19 +91,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
|  | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
|  | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||
| 		apiVersion := GetAPIVersion(c) | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | ||||
| 	} | ||||
|  | ||||
| 	requestBody := c.Request.Body | ||||
|  | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
|  | ||||
| 	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		req.Header.Set("api-key", apiKey) | ||||
| 		req.ContentLength = c.Request.ContentLength | ||||
| 	} else { | ||||
| 		req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
|  | ||||
| @@ -91,47 +134,44 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	var audioResponse AudioResponse | ||||
|  | ||||
| 	if relayMode != RelayModeAudioSpeech { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		var whisperResponse WhisperResponse | ||||
| 		err = json.Unmarshal(responseBody, &whisperResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		quota = countTokenText(whisperResponse.Text, audioModel) | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		if preConsumedQuota > 0 { | ||||
| 			// we need to roll back the pre-consumed quota | ||||
| 			defer func(ctx context.Context) { | ||||
| 				go func() { | ||||
| 					// negative means add quota back for token & user | ||||
| 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||
| 					if err != nil { | ||||
| 						common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) | ||||
| 					} | ||||
| 				}() | ||||
| 			}(c.Request.Context()) | ||||
| 		} | ||||
| 		return relayErrorHandler(resp) | ||||
| 	} | ||||
| 	quotaDelta := quota - preConsumedQuota | ||||
| 	defer func(ctx context.Context) { | ||||
| 		go func() { | ||||
| 			quota := countTokenText(audioResponse.Text, audioModel) | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		}() | ||||
| 		go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) | ||||
| 	}(c.Request.Context()) | ||||
|  | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &audioResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
|   | ||||
| @@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "assistant", | ||||
| @@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 		} else if message.Role == "assistant" { | ||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||
| 		} else if message.Role == "system" { | ||||
| 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||
| 			if prompt == "" { | ||||
| 				prompt = message.StringContent() | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	prompt += "\n\nAssistant:" | ||||
|   | ||||
| @@ -10,41 +10,76 @@ import ( | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func isWithinRange(element string, value int) bool { | ||||
| 	if _, ok := common.DalleGenerationImageAmounts[element]; !ok { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	min := common.DalleGenerationImageAmounts[element][0] | ||||
| 	max := common.DalleGenerationImageAmounts[element][1] | ||||
|  | ||||
| 	return value >= min && value <= max | ||||
| } | ||||
|  | ||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	imageModel := "dall-e" | ||||
| 	imageModel := "dall-e-2" | ||||
| 	imageSize := "1024x1024" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	var imageRequest ImageRequest | ||||
| 	if consumeQuota { | ||||
| 		err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Size validation | ||||
| 	if imageRequest.Size != "" { | ||||
| 		imageSize = imageRequest.Size | ||||
| 	} | ||||
|  | ||||
| 	// Model validation | ||||
| 	if imageRequest.Model != "" { | ||||
| 		imageModel = imageRequest.Model | ||||
| 	} | ||||
|  | ||||
| 	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] | ||||
|  | ||||
| 	// Check if model is supported | ||||
| 	if hasValidSize { | ||||
| 		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { | ||||
| 			if imageSize == "1024x1024" { | ||||
| 				imageCostRatio *= 2 | ||||
| 			} else { | ||||
| 				imageCostRatio *= 1.5 | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 		return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Not "256x256", "512x512", or "1024x1024" | ||||
| 	if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { | ||||
| 		return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	// Check prompt length | ||||
| 	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { | ||||
| 		return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// N should between 1 and 10 | ||||
| 	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { | ||||
| 		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	// Number of generated images validation | ||||
| 	if isWithinRange(imageModel, imageRequest.N) == false { | ||||
| 		return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// map model name | ||||
| @@ -61,18 +96,21 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
|  | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { | ||||
| 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||
| 		apiVersion := GetAPIVersion(c) | ||||
| 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) | ||||
| 	} | ||||
|  | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| @@ -87,26 +125,23 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
|  | ||||
| 	sizeRatio := 1.0 | ||||
| 	// Size | ||||
| 	if imageRequest.Size == "256x256" { | ||||
| 		sizeRatio = 1 | ||||
| 	} else if imageRequest.Size == "512x512" { | ||||
| 		sizeRatio = 1.125 | ||||
| 	} else if imageRequest.Size == "1024x1024" { | ||||
| 		sizeRatio = 1.25 | ||||
| 	} | ||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||
|  | ||||
| 	if consumeQuota && userQuota-quota < 0 { | ||||
| 		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) | ||||
| 	if userQuota-quota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
|  | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 	token := c.Request.Header.Get("Authorization") | ||||
| 	if channelType == common.ChannelTypeAzure { // Azure authentication | ||||
| 		token = strings.TrimPrefix(token, "Bearer ") | ||||
| 		req.Header.Set("api-key", token) | ||||
| 	} else { | ||||
| 		req.Header.Set("Authorization", token) | ||||
| 	} | ||||
|  | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| @@ -127,43 +162,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	var textResponse ImageResponse | ||||
|  | ||||
| 	defer func(ctx context.Context) { | ||||
| 		if consumeQuota { | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error update user quota cache: " + err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				tokenName := c.GetString("token_name") | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				channelId := c.GetInt("channel_id") | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
| 		err := model.PostConsumeTokenQuota(tokenId, quota) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 		} | ||||
| 		err = model.CacheUpdateUserQuota(userId) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error update user quota cache: " + err.Error()) | ||||
| 		} | ||||
| 		if quota != 0 { | ||||
| 			tokenName := c.GetString("token_name") | ||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 		} | ||||
| 	}(c.Request.Context()) | ||||
|  | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
|  | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
|  | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
|   | ||||
| @@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse TextResponse | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		if textResponse.Error.Type != "" { | ||||
| 			return &OpenAIErrorWithStatusCode{ | ||||
| 				OpenAIError: textResponse.Error, | ||||
| 				StatusCode:  resp.StatusCode, | ||||
| 			}, nil | ||||
| 		} | ||||
| 		// Reset response body | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &textResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if textResponse.Error.Type != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: textResponse.Error, | ||||
| 			StatusCode:  resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	// Reset response body | ||||
| 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
|  | ||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 	// So the httpClient will be confused by the response. | ||||
| @@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err := io.Copy(c.Writer, resp.Body) | ||||
| 	_, err = io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| @@ -132,7 +131,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp | ||||
| 	if textResponse.Usage.TotalTokens == 0 { | ||||
| 		completionTokens := 0 | ||||
| 		for _, choice := range textResponse.Choices { | ||||
| 			completionTokens += countTokenText(choice.Message.Content, model) | ||||
| 			completionTokens += countTokenText(choice.Message.StringContent(), model) | ||||
| 		} | ||||
| 		textResponse.Usage = Usage{ | ||||
| 			PromptTokens:     promptTokens, | ||||
|   | ||||
| @@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| 	} | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		palmMessage := PaLMChatMessage{ | ||||
| 			Content: message.Content, | ||||
| 			Content: message.StringContent(), | ||||
| 		} | ||||
| 		if message.Role == "user" { | ||||
| 			palmMessage.Author = "0" | ||||
|   | ||||
							
								
								
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,287 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://cloud.tencent.com/document/product/1729/97732 | ||||
|  | ||||
| type TencentMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type TencentChatRequest struct { | ||||
| 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID | ||||
| 	SecretId string `json:"secret_id"` // 官网 SecretId | ||||
| 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 | ||||
| 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 | ||||
| 	Timestamp int64 `json:"timestamp"` | ||||
| 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, | ||||
| 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 | ||||
| 	Expired int64  `json:"expired"` | ||||
| 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 | ||||
| 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 | ||||
| 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 | ||||
| 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p | ||||
| 	Temperature float64 `json:"temperature"` | ||||
| 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 | ||||
| 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 | ||||
| 	// 建议该参数和 temperature 只设置1个,不要同时更改 | ||||
| 	TopP float64 `json:"top_p"` | ||||
| 	// Stream 0:同步,1:流式 (默认,协议:SSE) | ||||
| 	// 同步请求超时:60s,如果内容较长建议使用流式 | ||||
| 	Stream int `json:"stream"` | ||||
| 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 | ||||
| 	// 输入 content 总数最大支持 3000 token。 | ||||
| 	Messages []TencentMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| type TencentError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type TencentUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| 	TotalTokens  int `json:"total_tokens"` | ||||
| } | ||||
|  | ||||
| type TencentResponseChoices struct { | ||||
| 	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||
| 	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| 	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||
| } | ||||
|  | ||||
| type TencentChatResponse struct { | ||||
| 	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 | ||||
| 	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串 | ||||
| 	Id      string                   `json:"id,omitempty"`      // 会话 id | ||||
| 	Usage   Usage                    `json:"usage,omitempty"`   // token 数量 | ||||
| 	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||
| 	Note    string                   `json:"note,omitempty"`    // 注释 | ||||
| 	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | ||||
| 	messages := make([]TencentMessage, 0, len(request.Messages)) | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, TencentMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, TencentMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} | ||||
| 		messages = append(messages, TencentMessage{ | ||||
| 			Content: message.StringContent(), | ||||
| 			Role:    message.Role, | ||||
| 		}) | ||||
| 	} | ||||
| 	stream := 0 | ||||
| 	if request.Stream { | ||||
| 		stream = 1 | ||||
| 	} | ||||
| 	return &TencentChatRequest{ | ||||
| 		Timestamp:   common.GetTimestamp(), | ||||
| 		Expired:     common.GetTimestamp() + 24*60*60, | ||||
| 		QueryID:     common.GetUUID(), | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Stream:      stream, | ||||
| 		Messages:    messages, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Usage:   response.Usage, | ||||
| 	} | ||||
| 	if len(response.Choices) > 0 { | ||||
| 		choice := OpenAITextResponseChoice{ | ||||
| 			Index: 0, | ||||
| 			Message: Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: response.Choices[0].Messages.Content, | ||||
| 			}, | ||||
| 			FinishReason: response.Choices[0].FinishReason, | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "tencent-hunyuan", | ||||
| 	} | ||||
| 	if len(TencentResponse.Choices) > 0 { | ||||
| 		var choice ChatCompletionsStreamResponseChoice | ||||
| 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content | ||||
| 		if TencentResponse.Choices[0].FinishReason == "stop" { | ||||
| 			choice.FinishReason = &stopFinishReason | ||||
| 		} | ||||
| 		response.Choices = append(response.Choices, choice) | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	var responseText string | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	setEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var TencentResponse TencentChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &TencentResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response := streamResponseTencent2OpenAI(&TencentResponse) | ||||
| 			if len(response.Choices) != 0 { | ||||
| 				responseText += response.Choices[0].Delta.Content | ||||
| 			} | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var TencentResponse TencentChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &TencentResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if TencentResponse.Error.Code != 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: TencentResponse.Error.Message, | ||||
| 				Code:    TencentResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseTencent2OpenAI(&TencentResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { | ||||
| 	parts := strings.Split(config, "|") | ||||
| 	if len(parts) != 3 { | ||||
| 		err = errors.New("invalid tencent config") | ||||
| 		return | ||||
| 	} | ||||
| 	appId, err = strconv.ParseInt(parts[0], 10, 64) | ||||
| 	secretId = parts[1] | ||||
| 	secretKey = parts[2] | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func getTencentSign(req TencentChatRequest, secretKey string) string { | ||||
| 	params := make([]string, 0) | ||||
| 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | ||||
| 	params = append(params, "secret_id="+req.SecretId) | ||||
| 	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) | ||||
| 	params = append(params, "query_id="+req.QueryID) | ||||
| 	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) | ||||
| 	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) | ||||
| 	params = append(params, "stream="+strconv.Itoa(req.Stream)) | ||||
| 	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) | ||||
|  | ||||
| 	var messageStr string | ||||
| 	for _, msg := range req.Messages { | ||||
| 		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) | ||||
| 	} | ||||
| 	messageStr = strings.TrimSuffix(messageStr, ",") | ||||
| 	params = append(params, "messages=["+messageStr+"]") | ||||
|  | ||||
| 	sort.Sort(sort.StringSlice(params)) | ||||
| 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") | ||||
| 	mac := hmac.New(sha1.New, []byte(secretKey)) | ||||
| 	signURL := url | ||||
| 	mac.Write([]byte(signURL)) | ||||
| 	sign := mac.Sum([]byte(nil)) | ||||
| 	return base64.StdEncoding.EncodeToString(sign) | ||||
| } | ||||
| @@ -6,13 +6,15 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -24,13 +26,21 @@ const ( | ||||
| 	APITypeAli | ||||
| 	APITypeXunfei | ||||
| 	APITypeAIProxyLibrary | ||||
| 	APITypeTencent | ||||
| ) | ||||
|  | ||||
| var httpClient *http.Client | ||||
| var impatientHTTPClient *http.Client | ||||
|  | ||||
| func init() { | ||||
| 	httpClient = &http.Client{} | ||||
| 	if common.RelayTimeout == 0 { | ||||
| 		httpClient = &http.Client{} | ||||
| 	} else { | ||||
| 		httpClient = &http.Client{ | ||||
| 			Timeout: time.Duration(common.RelayTimeout) * time.Second, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	impatientHTTPClient = &http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| @@ -41,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
| 	var textRequest GeneralOpenAIRequest | ||||
| 	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { | ||||
| 		err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		} | ||||
| 	err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { | ||||
| 		textRequest.Model = "text-moderation-latest" | ||||
| @@ -109,22 +116,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 		apiType = APITypeXunfei | ||||
| 	case common.ChannelTypeAIProxyLibrary: | ||||
| 		apiType = APITypeAIProxyLibrary | ||||
| 	case common.ChannelTypeTencent: | ||||
| 		apiType = APITypeTencent | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if channelType == common.ChannelTypeAzure { | ||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 			query := c.Request.URL.Query() | ||||
| 			apiVersion := query.Get("api-version") | ||||
| 			if apiVersion == "" { | ||||
| 				apiVersion = c.GetString("api_version") | ||||
| 			} | ||||
| 			apiVersion := GetAPIVersion(c) | ||||
| 			requestURL := strings.Split(requestURL, "?")[0] | ||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||
| 			baseURL = c.GetString("base_url") | ||||
| @@ -135,7 +140,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | ||||
| 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||
|  | ||||
| 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) | ||||
| 			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||
| @@ -148,6 +155,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 		case "ERNIE-Bot-turbo": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "ERNIE-Bot-4": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| @@ -179,6 +188,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 		if relayMode == RelayModeEmbeddings { | ||||
| 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||
| 		} | ||||
| 	case APITypeTencent: | ||||
| 		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | ||||
| 	} | ||||
| @@ -204,6 +215,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	if userQuota-preConsumedQuota < 0 { | ||||
| 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| 	} | ||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||
| @@ -214,7 +228,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 		preConsumedQuota = 0 | ||||
| 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||
| 	} | ||||
| 	if consumeQuota && preConsumedQuota > 0 { | ||||
| 	if preConsumedQuota > 0 { | ||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) | ||||
| @@ -282,6 +296,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeTencent: | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		appId, secretId, secretKey, err := parseTencentConfig(apiKey) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		tencentRequest := requestOpenAI2Tencent(textRequest) | ||||
| 		tencentRequest.AppId = appId | ||||
| 		tencentRequest.SecretId = secretId | ||||
| 		jsonStr, err := json.Marshal(tencentRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		sign := getTencentSign(*tencentRequest, secretKey) | ||||
| 		c.Request.Header.Set("Authorization", sign) | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeAIProxyLibrary: | ||||
| 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | ||||
| 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||
| @@ -329,11 +360,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if textRequest.Stream { | ||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 			} | ||||
| 		case APITypeTencent: | ||||
| 			req.Header.Set("Authorization", apiKey) | ||||
| 		case APITypePaLM: | ||||
| 			// do not set Authorization header | ||||
| 		default: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 		} | ||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 		if isStream && c.Request.Header.Get("Accept") == "" { | ||||
| 			req.Header.Set("Accept", "text/event-stream") | ||||
| 		} | ||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| @@ -369,39 +407,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	defer func(ctx context.Context) { | ||||
| 		// c.Writer.Flush() | ||||
| 		go func() { | ||||
| 			if consumeQuota { | ||||
| 				quota := 0 | ||||
| 				completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 				promptTokens = textResponse.Usage.PromptTokens | ||||
| 				completionTokens = textResponse.Usage.CompletionTokens | ||||
|  | ||||
| 				quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||
| 				quota = int(float64(quota) * ratio) | ||||
| 				if ratio != 0 && quota <= 0 { | ||||
| 					quota = 1 | ||||
| 				} | ||||
| 				totalTokens := promptTokens + completionTokens | ||||
| 				if totalTokens == 0 { | ||||
| 					// in this case, must be some error happened | ||||
| 					// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 					quota = 0 | ||||
| 				} | ||||
| 				quotaDelta := quota - preConsumedQuota | ||||
| 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 				} | ||||
| 				err = model.CacheUpdateUserQuota(userId) | ||||
| 				if err != nil { | ||||
| 					common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||
| 				} | ||||
| 				if quota != 0 { | ||||
| 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||
| 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 					model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 				} | ||||
| 			quota := 0 | ||||
| 			completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 			promptTokens = textResponse.Usage.PromptTokens | ||||
| 			completionTokens = textResponse.Usage.CompletionTokens | ||||
| 			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||
| 			if ratio != 0 && quota <= 0 { | ||||
| 				quota = 1 | ||||
| 			} | ||||
| 			totalTokens := promptTokens + completionTokens | ||||
| 			if totalTokens == 0 { | ||||
| 				// in this case, must be some error happened | ||||
| 				// we cannot just return, because we may have to return the pre-consumed quota | ||||
| 				quota = 0 | ||||
| 			} | ||||
| 			quotaDelta := quota - preConsumedQuota | ||||
| 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 			} | ||||
| 			err = model.CacheUpdateUserQuota(userId) | ||||
| 			if err != nil { | ||||
| 				common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||
| 			} | ||||
| 			if quota != 0 { | ||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 			} | ||||
|  | ||||
| 		}() | ||||
| 	}(c.Request.Context()) | ||||
| 	switch apiType { | ||||
| @@ -415,7 +450,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) | ||||
| 			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @@ -581,6 +616,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeTencent: | ||||
| 		if isStream { | ||||
| 			err, responseText := tencentStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := tencentHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	default: | ||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||
| 	} | ||||
|   | ||||
| @@ -1,52 +1,64 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/pkoukk/tiktoken-go" | ||||
| ) | ||||
|  | ||||
| var stopFinishReason = "stop" | ||||
|  | ||||
| // tokenEncoderMap won't grow after initialization | ||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||
| var defaultTokenEncoder *tiktoken.Tiktoken | ||||
|  | ||||
| func InitTokenEncoders() { | ||||
| 	common.SysLog("initializing token encoders") | ||||
| 	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||
| 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||
| 	if err != nil { | ||||
| 		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) | ||||
| 		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	defaultTokenEncoder = gpt35TokenEncoder | ||||
| 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||
| 	if err != nil { | ||||
| 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||
| 	} | ||||
| 	for model, _ := range common.ModelRatio { | ||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||
| 		if err != nil { | ||||
| 			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) | ||||
| 			tokenEncoderMap[model] = fallbackTokenEncoder | ||||
| 			continue | ||||
| 		if strings.HasPrefix(model, "gpt-3.5") { | ||||
| 			tokenEncoderMap[model] = gpt35TokenEncoder | ||||
| 		} else if strings.HasPrefix(model, "gpt-4") { | ||||
| 			tokenEncoderMap[model] = gpt4TokenEncoder | ||||
| 		} else { | ||||
| 			tokenEncoderMap[model] = nil | ||||
| 		} | ||||
| 		tokenEncoderMap[model] = tokenEncoder | ||||
| 	} | ||||
| 	common.SysLog("token encoders initialized") | ||||
| } | ||||
|  | ||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||
| 	if tokenEncoder, ok := tokenEncoderMap[model]; ok { | ||||
| 	tokenEncoder, ok := tokenEncoderMap[model] | ||||
| 	if ok && tokenEncoder != nil { | ||||
| 		return tokenEncoder | ||||
| 	} | ||||
| 	tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||
| 	if err != nil { | ||||
| 		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||
| 		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||
| 	if ok { | ||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||
| 		if err != nil { | ||||
| 			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) | ||||
| 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||
| 			tokenEncoder = defaultTokenEncoder | ||||
| 		} | ||||
| 		tokenEncoderMap[model] = tokenEncoder | ||||
| 		return tokenEncoder | ||||
| 	} | ||||
| 	tokenEncoderMap[model] = tokenEncoder | ||||
| 	return tokenEncoder | ||||
| 	return defaultTokenEncoder | ||||
| } | ||||
|  | ||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||
| @@ -75,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int { | ||||
| 	tokenNum := 0 | ||||
| 	for _, message := range messages { | ||||
| 		tokenNum += tokensPerMessage | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Content) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.StringContent()) | ||||
| 		tokenNum += getTokenNum(tokenEncoder, message.Role) | ||||
| 		if message.Name != nil { | ||||
| 			tokenNum += tokensPerName | ||||
| @@ -133,6 +145,19 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func shouldEnableChannel(err error, openAIErr *OpenAIError) bool { | ||||
| 	if !common.AutomaticEnableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if openAIErr != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func setEventStreamHeaders(c *gin.Context) { | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| @@ -167,3 +192,48 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr | ||||
| 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func getFullRequestURL(baseURL string, requestURL string, channelType int) string { | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
|  | ||||
| 	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { | ||||
| 		switch channelType { | ||||
| 		case common.ChannelTypeOpenAI: | ||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) | ||||
| 		case common.ChannelTypeAzure: | ||||
| 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) | ||||
| 		} | ||||
| 	} | ||||
| 	return fullRequestURL | ||||
| } | ||||
|  | ||||
| func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||
| 	// quotaDelta is remaining quota to be consumed | ||||
| 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 	if err != nil { | ||||
| 		common.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 	} | ||||
| 	err = model.CacheUpdateUserQuota(userId) | ||||
| 	if err != nil { | ||||
| 		common.SysError("error update user quota cache: " + err.Error()) | ||||
| 	} | ||||
| 	// totalQuota is total quota consumed | ||||
| 	if totalQuota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||
| 	} | ||||
| 	if totalQuota <= 0 { | ||||
| 		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetAPIVersion(c *gin.Context) string { | ||||
| 	query := c.Request.URL.Query() | ||||
| 	apiVersion := query.Get("api-version") | ||||
| 	if apiVersion == "" { | ||||
| 		apiVersion = c.GetString("api_version") | ||||
| 	} | ||||
| 	return apiVersion | ||||
| } | ||||
|   | ||||
| @@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "assistant", | ||||
| @@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma | ||||
| 		} else { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| @@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin | ||||
| 	for !stop { | ||||
| 		select { | ||||
| 		case xunfeiResponse = <-dataChan: | ||||
| 			if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
| 			content += xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||
| @@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, | ||||
| 		common.SysLog("api_version not found, use default: " + apiVersion) | ||||
| 	} | ||||
| 	domain := "general" | ||||
| 	if apiVersion == "v2.1" { | ||||
| 		domain = "generalv2" | ||||
| 	if apiVersion != "v1.1" { | ||||
| 		domain += strings.Split(apiVersion, ".")[0] | ||||
| 	} | ||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	return domain, authUrl | ||||
|   | ||||
| @@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "user", | ||||
| @@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 		} else { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -12,10 +12,49 @@ import ( | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string  `json:"role"` | ||||
| 	Content string  `json:"content"` | ||||
| 	Content any     `json:"content"` | ||||
| 	Name    *string `json:"name,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageURL struct { | ||||
| 	Url    string `json:"url,omitempty"` | ||||
| 	Detail string `json:"detail,omitempty"` | ||||
| } | ||||
|  | ||||
| type TextContent struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type ImageContent struct { | ||||
| 	Type     string    `json:"type,omitempty"` | ||||
| 	ImageURL *ImageURL `json:"image_url,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m Message) StringContent() string { | ||||
| 	content, ok := m.Content.(string) | ||||
| 	if ok { | ||||
| 		return content | ||||
| 	} | ||||
| 	contentList, ok := m.Content.([]any) | ||||
| 	if ok { | ||||
| 		var contentStr string | ||||
| 		for _, contentItem := range contentList { | ||||
| 			contentMap, ok := contentItem.(map[string]any) | ||||
| 			if !ok { | ||||
| 				continue | ||||
| 			} | ||||
| 			if contentMap["type"] == "text" { | ||||
| 				if subStr, ok := contentMap["text"].(string); ok { | ||||
| 					contentStr += subStr | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return contentStr | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	RelayModeUnknown = iota | ||||
| 	RelayModeChatCompletions | ||||
| @@ -24,24 +63,37 @@ const ( | ||||
| 	RelayModeModerations | ||||
| 	RelayModeImagesGenerations | ||||
| 	RelayModeEdits | ||||
| 	RelayModeAudio | ||||
| 	RelayModeAudioSpeech | ||||
| 	RelayModeAudioTranscription | ||||
| 	RelayModeAudioTranslation | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/chat | ||||
|  | ||||
| type ResponseFormat struct { | ||||
| 	Type string `json:"type,omitempty"` | ||||
| } | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Model       string    `json:"model,omitempty"` | ||||
| 	Messages    []Message `json:"messages,omitempty"` | ||||
| 	Prompt      any       `json:"prompt,omitempty"` | ||||
| 	Stream      bool      `json:"stream,omitempty"` | ||||
| 	MaxTokens   int       `json:"max_tokens,omitempty"` | ||||
| 	Temperature float64   `json:"temperature,omitempty"` | ||||
| 	TopP        float64   `json:"top_p,omitempty"` | ||||
| 	N           int       `json:"n,omitempty"` | ||||
| 	Input       any       `json:"input,omitempty"` | ||||
| 	Instruction string    `json:"instruction,omitempty"` | ||||
| 	Size        string    `json:"size,omitempty"` | ||||
| 	Functions   any       `json:"functions,omitempty"` | ||||
| 	Model            string          `json:"model,omitempty"` | ||||
| 	Messages         []Message       `json:"messages,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Stream           bool            `json:"stream,omitempty"` | ||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	N                int             `json:"n,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| 	Functions        any             `json:"functions,omitempty"` | ||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed             float64         `json:"seed,omitempty"` | ||||
| 	Tools            any             `json:"tools,omitempty"` | ||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
| @@ -77,16 +129,30 @@ type TextRequest struct { | ||||
| 	//Stream   bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| // ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create | ||||
| type ImageRequest struct { | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	N      int    `json:"n"` | ||||
| 	Size   string `json:"size"` | ||||
| 	Model          string `json:"model"` | ||||
| 	Prompt         string `json:"prompt" binding:"required"` | ||||
| 	N              int    `json:"n,omitempty"` | ||||
| 	Size           string `json:"size,omitempty"` | ||||
| 	Quality        string `json:"quality,omitempty"` | ||||
| 	ResponseFormat string `json:"response_format,omitempty"` | ||||
| 	Style          string `json:"style,omitempty"` | ||||
| 	User           string `json:"user,omitempty"` | ||||
| } | ||||
|  | ||||
| type AudioResponse struct { | ||||
| type WhisperResponse struct { | ||||
| 	Text string `json:"text,omitempty"` | ||||
| } | ||||
|  | ||||
| type TextToSpeechRequest struct { | ||||
| 	Model          string  `json:"model" binding:"required"` | ||||
| 	Input          string  `json:"input" binding:"required"` | ||||
| 	Voice          string  `json:"voice" binding:"required"` | ||||
| 	Speed          float64 `json:"speed"` | ||||
| 	ResponseFormat string  `json:"response_format"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	PromptTokens     int `json:"prompt_tokens"` | ||||
| 	CompletionTokens int `json:"completion_tokens"` | ||||
| @@ -183,14 +249,22 @@ func Relay(c *gin.Context) { | ||||
| 		relayMode = RelayModeImagesGenerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { | ||||
| 		relayMode = RelayModeEdits | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 		relayMode = RelayModeAudio | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { | ||||
| 		relayMode = RelayModeAudioSpeech | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { | ||||
| 		relayMode = RelayModeAudioTranscription | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 		relayMode = RelayModeAudioTranslation | ||||
| 	} | ||||
| 	var err *OpenAIErrorWithStatusCode | ||||
| 	switch relayMode { | ||||
| 	case RelayModeImagesGenerations: | ||||
| 		err = relayImageHelper(c, relayMode) | ||||
| 	case RelayModeAudio: | ||||
| 	case RelayModeAudioSpeech: | ||||
| 		fallthrough | ||||
| 	case RelayModeAudioTranslation: | ||||
| 		fallthrough | ||||
| 	case RelayModeAudioTranscription: | ||||
| 		err = relayAudioHelper(c, relayMode) | ||||
| 	default: | ||||
| 		err = relayTextHelper(c, relayMode) | ||||
|   | ||||
| @@ -9,21 +9,21 @@ services: | ||||
|     ports: | ||||
|       - "3000:3000" | ||||
|     volumes: | ||||
|       - ./data:/data | ||||
|       - ./data/oneapi:/data | ||||
|       - ./logs:/app/logs | ||||
|     environment: | ||||
|       - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库 | ||||
|       - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库 | ||||
|       - REDIS_CONN_STRING=redis://redis | ||||
|       - SESSION_SECRET=random_string  # 修改为随机字符串 | ||||
|       - TZ=Asia/Shanghai | ||||
| #      - NODE_TYPE=slave  # 多机部署时从节点取消注释该行 | ||||
| #      - SYNC_FREQUENCY=60  # 需要定期从数据库加载数据时取消注释该行 | ||||
| #      - FRONTEND_BASE_URL=https://openai.justsong.cn  # 多机部署时从节点取消注释该行 | ||||
|  | ||||
|     depends_on: | ||||
|       - redis | ||||
|       - db | ||||
|     healthcheck: | ||||
|       test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] | ||||
|       test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] | ||||
|       interval: 30s | ||||
|       timeout: 10s | ||||
|       retries: 3 | ||||
| @@ -32,3 +32,18 @@ services: | ||||
|     image: redis:latest | ||||
|     container_name: redis | ||||
|     restart: always | ||||
|  | ||||
|   db: | ||||
|     image: mysql:8.2.0 | ||||
|     restart: always | ||||
|     container_name: mysql | ||||
|     volumes: | ||||
|       - ./data/mysql:/var/lib/mysql  # 挂载目录,持久化存储 | ||||
|     ports: | ||||
|       - '3306:3306' | ||||
|     environment: | ||||
|       TZ: Asia/Shanghai   # 设置时区 | ||||
|       MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码 | ||||
|       MYSQL_USER: oneapi   # 创建专用用户 | ||||
|       MYSQL_PASSWORD: '123456'    # 设置专用用户密码 | ||||
|       MYSQL_DATABASE: one-api   # 自动创建数据库 | ||||
							
								
								
									
										10
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.mod
									
									
									
									
									
								
							| @@ -15,8 +15,9 @@ require ( | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||
| 	golang.org/x/crypto v0.9.0 | ||||
| 	golang.org/x/crypto v0.14.0 | ||||
| 	gorm.io/driver/mysql v1.4.3 | ||||
| 	gorm.io/driver/postgres v1.5.2 | ||||
| 	gorm.io/driver/sqlite v1.4.3 | ||||
| 	gorm.io/gorm v1.25.0 | ||||
| ) | ||||
| @@ -52,10 +53,9 @@ require ( | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| 	golang.org/x/net v0.10.0 // indirect | ||||
| 	golang.org/x/sys v0.8.0 // indirect | ||||
| 	golang.org/x/text v0.9.0 // indirect | ||||
| 	golang.org/x/net v0.17.0 // indirect | ||||
| 	golang.org/x/sys v0.13.0 // indirect | ||||
| 	golang.org/x/text v0.13.0 // indirect | ||||
| 	google.golang.org/protobuf v1.30.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| 	gorm.io/driver/postgres v1.5.2 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										17
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								go.sum
									
									
									
									
									
								
							| @@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu | ||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= | ||||
| golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= | ||||
| golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= | ||||
| golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= | ||||
| golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= | ||||
| golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | ||||
| golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | ||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| @@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc | ||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= | ||||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= | ||||
| golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= | ||||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | ||||
| golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= | ||||
| golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| @@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp | ||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | ||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | ||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | ||||
| gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= | ||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | ||||
| gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | ||||
| gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||
|   | ||||
| @@ -119,6 +119,7 @@ | ||||
|   " 年 ": " y ", | ||||
|   "未测试": "Not tested", | ||||
|   "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", | ||||
|   "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", | ||||
|   "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", | ||||
|   "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", | ||||
| @@ -139,6 +140,7 @@ | ||||
|   "启用": "Enable", | ||||
|   "编辑": "Edit", | ||||
|   "添加新的渠道": "Add a new channel", | ||||
|   "测试所有通道": "Test all channels", | ||||
|   "测试所有已启用通道": "Test all enabled channels", | ||||
|   "更新所有已启用通道余额": "Update the balance of all enabled channels", | ||||
|   "刷新": "Refresh", | ||||
|   | ||||
							
								
								
									
										20
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								main.go
									
									
									
									
									
								
							| @@ -2,6 +2,7 @@ package main | ||||
|  | ||||
| import ( | ||||
| 	"embed" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-contrib/sessions/cookie" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -50,18 +51,17 @@ func main() { | ||||
| 	// Initialize options | ||||
| 	model.InitOptionMap() | ||||
| 	if common.RedisEnabled { | ||||
| 		// for compatibility with old versions | ||||
| 		common.MemoryCacheEnabled = true | ||||
| 	} | ||||
| 	if common.MemoryCacheEnabled { | ||||
| 		common.SysLog("memory cache enabled") | ||||
| 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) | ||||
| 		model.InitChannelCache() | ||||
| 	} | ||||
| 	if os.Getenv("SYNC_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY")) | ||||
| 		if err != nil { | ||||
| 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) | ||||
| 		} | ||||
| 		common.SyncFrequency = frequency | ||||
| 		go model.SyncOptions(frequency) | ||||
| 		if common.RedisEnabled { | ||||
| 			go model.SyncChannelCache(frequency) | ||||
| 		} | ||||
| 	if common.MemoryCacheEnabled { | ||||
| 		go model.SyncOptions(common.SyncFrequency) | ||||
| 		go model.SyncChannelCache(common.SyncFrequency) | ||||
| 	} | ||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||
|   | ||||
| @@ -94,7 +94,7 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		userEnabled, err := model.IsUserEnabled(token.UserId) | ||||
| 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | ||||
| 		if err != nil { | ||||
| 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| 			return | ||||
| @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 		c.Set("id", token.UserId) | ||||
| 		c.Set("token_id", token.Id) | ||||
| 		c.Set("token_name", token.Name) | ||||
| 		requestURL := c.Request.URL.String() | ||||
| 		consumeQuota := true | ||||
| 		if strings.HasPrefix(requestURL, "/v1/models") { | ||||
| 			consumeQuota = false | ||||
| 		} | ||||
| 		c.Set("consume_quota", consumeQuota) | ||||
| 		if len(parts) > 1 { | ||||
| 			if model.IsAdmin(token.UserId) { | ||||
| 				c.Set("channelId", parts[1]) | ||||
|   | ||||
| @@ -25,12 +25,12 @@ func Distribute() func(c *gin.Context) { | ||||
| 		if ok { | ||||
| 			id, err := strconv.Atoi(channelId.(string)) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||
| 				return | ||||
| 			} | ||||
| 			channel, err = model.GetChannelById(id, true) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||
| 				return | ||||
| 			} | ||||
| 			if channel.Status != common.ChannelStatusEnabled { | ||||
| @@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 		} else { | ||||
| 			// Select a channel for the user | ||||
| 			var modelRequest ModelRequest | ||||
| 			var err error | ||||
| 			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 				err = common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			} | ||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||
| 				return | ||||
| @@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) { | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e" | ||||
| 					modelRequest.Model = "dall-e-2" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "whisper-1" | ||||
| 				} | ||||
| @@ -82,9 +79,9 @@ func Distribute() func(c *gin.Context) { | ||||
| 		c.Set("channel", channel.Type) | ||||
| 		c.Set("channel_id", channel.Id) | ||||
| 		c.Set("channel_name", channel.Name) | ||||
| 		c.Set("model_mapping", channel.ModelMapping) | ||||
| 		c.Set("model_mapping", channel.GetModelMapping()) | ||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
| 		c.Set("base_url", channel.BaseURL) | ||||
| 		c.Set("base_url", channel.GetBaseURL()) | ||||
| 		switch channel.Type { | ||||
| 		case common.ChannelTypeAzure: | ||||
| 			c.Set("api_version", channel.Other) | ||||
|   | ||||
| @@ -10,16 +10,25 @@ type Ability struct { | ||||
| 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"` | ||||
| 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | ||||
| 	Enabled   bool   `json:"enabled"` | ||||
| 	Priority  int64  `json:"priority" gorm:"bigint;default:0"` | ||||
| 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"` | ||||
| } | ||||
|  | ||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	ability := Ability{} | ||||
| 	groupCol := "`group`" | ||||
| 	trueVal := "1" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 		trueVal = "true" | ||||
| 	} | ||||
|  | ||||
| 	var err error = nil | ||||
| 	if common.UsingSQLite { | ||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error | ||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||
| 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	if common.UsingSQLite || common.UsingPostgreSQL { | ||||
| 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||
| 	} else { | ||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error | ||||
| 		err = channelQuery.Order("RAND()").First(&ability).Error | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|   | ||||
| @@ -21,14 +21,18 @@ var ( | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	var token Token | ||||
| 	if !common.RedisEnabled { | ||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | ||||
| 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||
| 		return &token, err | ||||
| 	} | ||||
| 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) | ||||
| 	if err != nil { | ||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | ||||
| 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -165,7 +169,7 @@ func InitChannelCache() { | ||||
| 	for group, model2channels := range newGroup2model2channels { | ||||
| 		for model, channels := range model2channels { | ||||
| 			sort.Slice(channels, func(i, j int) bool { | ||||
| 				return channels[i].Priority > channels[j].Priority | ||||
| 				return channels[i].GetPriority() > channels[j].GetPriority() | ||||
| 			}) | ||||
| 			newGroup2model2channels[group][model] = channels | ||||
| 		} | ||||
| @@ -186,7 +190,7 @@ func SyncChannelCache(frequency int) { | ||||
| } | ||||
|  | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	if !common.RedisEnabled { | ||||
| 	if !common.MemoryCacheEnabled { | ||||
| 		return GetRandomSatisfiedChannel(group, model) | ||||
| 	} | ||||
| 	channelSyncLock.RLock() | ||||
| @@ -195,11 +199,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | ||||
| 	if len(channels) == 0 { | ||||
| 		return nil, errors.New("channel not found") | ||||
| 	} | ||||
| 	endIdx := len(channels) | ||||
| 	// choose by priority | ||||
| 	firstChannel := channels[0] | ||||
| 	if firstChannel.Priority > 0 { | ||||
| 		return firstChannel, nil | ||||
| 	if firstChannel.GetPriority() > 0 { | ||||
| 		for i := range channels { | ||||
| 			if channels[i].GetPriority() != firstChannel.GetPriority() { | ||||
| 				endIdx = i | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	idx := rand.Intn(len(channels)) | ||||
| 	idx := rand.Intn(endIdx) | ||||
| 	return channels[idx], nil | ||||
| } | ||||
|   | ||||
| @@ -11,19 +11,19 @@ type Channel struct { | ||||
| 	Key                string  `json:"key" gorm:"not null;index"` | ||||
| 	Status             int     `json:"status" gorm:"default:1"` | ||||
| 	Name               string  `json:"name" gorm:"index"` | ||||
| 	Weight             int     `json:"weight"` | ||||
| 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||
| 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | ||||
| 	TestTime           int64   `json:"test_time" gorm:"bigint"` | ||||
| 	ResponseTime       int     `json:"response_time"` // in milliseconds | ||||
| 	BaseURL            string  `json:"base_url" gorm:"column:base_url"` | ||||
| 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | ||||
| 	Other              string  `json:"other"` | ||||
| 	Balance            float64 `json:"balance"` // in USD | ||||
| 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | ||||
| 	Models             string  `json:"models"` | ||||
| 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||
| 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	Priority           int64   `json:"priority" gorm:"bigint;default:0"` | ||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| @@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| } | ||||
|  | ||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| @@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { | ||||
| 	return &channel, err | ||||
| } | ||||
|  | ||||
| func GetRandomChannel() (*Channel, error) { | ||||
| 	channel := Channel{} | ||||
| 	var err error = nil | ||||
| 	if common.UsingSQLite { | ||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error | ||||
| 	} else { | ||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error | ||||
| 	} | ||||
| 	return &channel, err | ||||
| } | ||||
|  | ||||
| func BatchInsertChannels(channels []Channel) error { | ||||
| 	var err error | ||||
| 	err = DB.Create(&channels).Error | ||||
| @@ -79,6 +72,27 @@ func BatchInsertChannels(channels []Channel) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (channel *Channel) GetPriority() int64 { | ||||
| 	if channel.Priority == nil { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return *channel.Priority | ||||
| } | ||||
|  | ||||
| func (channel *Channel) GetBaseURL() string { | ||||
| 	if channel.BaseURL == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return *channel.BaseURL | ||||
| } | ||||
|  | ||||
| func (channel *Channel) GetModelMapping() string { | ||||
| 	if channel.ModelMapping == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return *channel.ModelMapping | ||||
| } | ||||
|  | ||||
| func (channel *Channel) Insert() error { | ||||
| 	var err error | ||||
| 	err = DB.Create(channel).Error | ||||
| @@ -155,3 +169,13 @@ func updateChannelUsedQuota(id int, quota int) { | ||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func DeleteChannelByStatus(status int64) (int64, error) { | ||||
| 	result := DB.Where("status = ?", status).Delete(&Channel{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|  | ||||
| func DeleteDisabledChannel() (int64, error) { | ||||
| 	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||
| 	return result.RowsAffected, result.Error | ||||
| } | ||||
|   | ||||
							
								
								
									
										25
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -8,18 +8,18 @@ import ( | ||||
| ) | ||||
|  | ||||
| type Log struct { | ||||
| 	Id               int    `json:"id"` | ||||
| 	UserId           int    `json:"user_id"` | ||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index"` | ||||
| 	Type             int    `json:"type" gorm:"index"` | ||||
| 	Id               int    `json:"id;index:idx_created_at_id,priority:1"` | ||||
| 	UserId           int    `json:"user_id" gorm:"index"` | ||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` | ||||
| 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||
| 	Content          string `json:"content"` | ||||
| 	Username         string `json:"username" gorm:"index;default:''"` | ||||
| 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||
| 	TokenName        string `json:"token_name" gorm:"index;default:''"` | ||||
| 	ModelName        string `json:"model_name" gorm:"index;default:''"` | ||||
| 	ModelName        string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | ||||
| 	Quota            int    `json:"quota" gorm:"default:0"` | ||||
| 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | ||||
| 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | ||||
| 	Channel          int    `json:"channel" gorm:"default:0"` | ||||
| 	ChannelId        int    `json:"channel" gorm:"index"` | ||||
| } | ||||
|  | ||||
| const ( | ||||
| @@ -47,7 +47,6 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||
| 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| 	if !common.LogConsumeEnabled { | ||||
| @@ -64,7 +63,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | ||||
| 		TokenName:        tokenName, | ||||
| 		ModelName:        modelName, | ||||
| 		Quota:            quota, | ||||
| 		Channel:          channelId, | ||||
| 		ChannelId:        channelId, | ||||
| 	} | ||||
| 	err := DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| @@ -95,7 +94,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | ||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||
| 	} | ||||
| 	if channel != 0 { | ||||
| 		tx = tx.Where("channel = ?", channel) | ||||
| 		tx = tx.Where("channel_id = ?", channel) | ||||
| 	} | ||||
| 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | ||||
| 	return logs, err | ||||
| @@ -135,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | ||||
| 	tx := DB.Table("logs").Select("sum(quota)") | ||||
| 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
| @@ -152,14 +151,14 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | ||||
| 		tx = tx.Where("model_name = ?", modelName) | ||||
| 	} | ||||
| 	if channel != 0 { | ||||
| 		tx = tx.Where("channel = ?", channel) | ||||
| 		tx = tx.Where("channel_id = ?", channel) | ||||
| 	} | ||||
| 	tx.Where("type = ?", LogTypeConsume).Scan("a) | ||||
| 	return quota | ||||
| } | ||||
|  | ||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||
| 	tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)") | ||||
| 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
| 	} | ||||
|   | ||||
| @@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 		if strings.HasPrefix(dsn, "postgres://") { | ||||
| 			// Use PostgreSQL | ||||
| 			common.SysLog("using PostgreSQL as database") | ||||
| 			common.UsingPostgreSQL = true | ||||
| 			return gorm.Open(postgres.New(postgres.Config{ | ||||
| 				DSN:                  dsn, | ||||
| 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||
| @@ -81,6 +82,7 @@ func InitDB() (err error) { | ||||
| 		if !common.IsMasterNode { | ||||
| 			return nil | ||||
| 		} | ||||
| 		common.SysLog("database migration started") | ||||
| 		err = db.AutoMigrate(&Channel{}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
|   | ||||
| @@ -34,6 +34,7 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) | ||||
| 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) | ||||
| 	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) | ||||
| 	common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) | ||||
| 	common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) | ||||
| 	common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) | ||||
| 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | ||||
| @@ -147,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 			common.EmailDomainRestrictionEnabled = boolValue | ||||
| 		case "AutomaticDisableChannelEnabled": | ||||
| 			common.AutomaticDisableChannelEnabled = boolValue | ||||
| 		case "AutomaticEnableChannelEnabled": | ||||
| 			common.AutomaticEnableChannelEnabled = boolValue | ||||
| 		case "ApproximateTokenEnabled": | ||||
| 			common.ApproximateTokenEnabled = boolValue | ||||
| 		case "LogConsumeEnabled": | ||||
|   | ||||
| @@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 	} | ||||
| 	redemption := &Redemption{} | ||||
|  | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error | ||||
| 		if err != nil { | ||||
| 			return errors.New("无效的兑换码") | ||||
| 		} | ||||
|   | ||||
| @@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) { | ||||
| } | ||||
|  | ||||
| func GetUserGroup(id int) (group string, err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error | ||||
| 	groupCol := "`group`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 	} | ||||
|  | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error | ||||
| 	return group, err | ||||
| } | ||||
|  | ||||
| @@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) { | ||||
|  | ||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| 	if common.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) | ||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||
| 		return | ||||
| 	} | ||||
| 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||
| @@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func updateUserUsedQuota(id int, quota int) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||
| 		}, | ||||
| 	).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update user used quota: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func updateUserRequestCount(id int, count int) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | ||||
| 	if err != nil { | ||||
| 		common.SysError("failed to update user request count: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUsernameById(id int) (username string) { | ||||
| 	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) | ||||
| 	return username | ||||
|   | ||||
| @@ -6,13 +6,13 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock | ||||
|  | ||||
| const ( | ||||
| 	BatchUpdateTypeUserQuota = iota | ||||
| 	BatchUpdateTypeTokenQuota | ||||
| 	BatchUpdateTypeUsedQuotaAndRequestCount | ||||
| 	BatchUpdateTypeUsedQuota | ||||
| 	BatchUpdateTypeChannelUsedQuota | ||||
| 	BatchUpdateTypeRequestCount | ||||
| 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||
| ) | ||||
|  | ||||
| var batchUpdateStores []map[int]int | ||||
| @@ -51,7 +51,7 @@ func batchUpdate() { | ||||
| 		store := batchUpdateStores[i] | ||||
| 		batchUpdateStores[i] = make(map[int]int) | ||||
| 		batchUpdateLocks[i].Unlock() | ||||
|  | ||||
| 		// TODO: maybe we can combine updates with same key? | ||||
| 		for key, value := range store { | ||||
| 			switch i { | ||||
| 			case BatchUpdateTypeUserQuota: | ||||
| @@ -64,8 +64,10 @@ func batchUpdate() { | ||||
| 				if err != nil { | ||||
| 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||
| 				} | ||||
| 			case BatchUpdateTypeUsedQuotaAndRequestCount: | ||||
| 				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect | ||||
| 			case BatchUpdateTypeUsedQuota: | ||||
| 				updateUserUsedQuota(key, value) | ||||
| 			case BatchUpdateTypeRequestCount: | ||||
| 				updateUserRequestCount(key, value) | ||||
| 			case BatchUpdateTypeChannelUsedQuota: | ||||
| 				updateChannelUsedQuota(key, value) | ||||
| 			} | ||||
|   | ||||
							
								
								
									
										3
									
								
								pull_request_template.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								pull_request_template.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| close #issue_number | ||||
|  | ||||
| 我已确认该 PR 已自测通过,相关截图如下: | ||||
| @@ -74,6 +74,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | ||||
| 			channelRoute.POST("/", controller.AddChannel) | ||||
| 			channelRoute.PUT("/", controller.UpdateChannel) | ||||
| 			channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel) | ||||
| 			channelRoute.DELETE("/:id", controller.DeleteChannel) | ||||
| 		} | ||||
| 		tokenRoute := apiRouter.Group("/token") | ||||
|   | ||||
| @@ -29,17 +29,44 @@ func SetRelayRouter(router *gin.Engine) { | ||||
| 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/transcriptions", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/translations", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/speech", controller.Relay) | ||||
| 		relayV1Router.GET("/files", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/files", controller.RelayNotImplemented) | ||||
| 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/files/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented) | ||||
| 		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/moderations", controller.Relay) | ||||
| 		relayV1Router.POST("/assistants", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/assistants", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -283,7 +283,9 @@ function App() { | ||||
|           </Suspense> | ||||
|         } | ||||
|       /> | ||||
|       <Route path='*' element={NotFound} /> | ||||
|       <Route path='*' element={ | ||||
|           <NotFound /> | ||||
|       } /> | ||||
|     </Routes> | ||||
|   ); | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; | ||||
| import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; | ||||
| import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | ||||
|  | ||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||
| import { renderGroup, renderNumber } from '../helpers/render'; | ||||
| @@ -55,6 +55,7 @@ const ChannelsTable = () => { | ||||
|   const [searchKeyword, setSearchKeyword] = useState(''); | ||||
|   const [searching, setSearching] = useState(false); | ||||
|   const [updatingBalance, setUpdatingBalance] = useState(false); | ||||
|   const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test")); | ||||
|  | ||||
|   const loadChannels = async (startIdx) => { | ||||
|     const res = await API.get(`/api/channel/?p=${startIdx}`); | ||||
| @@ -96,7 +97,7 @@ const ChannelsTable = () => { | ||||
|       }); | ||||
|   }, []); | ||||
|  | ||||
|   const manageChannel = async (id, action, idx, priority) => { | ||||
|   const manageChannel = async (id, action, idx, value) => { | ||||
|     let data = { id }; | ||||
|     let res; | ||||
|     switch (action) { | ||||
| @@ -112,10 +113,20 @@ const ChannelsTable = () => { | ||||
|         res = await API.put('/api/channel/', data); | ||||
|         break; | ||||
|       case 'priority': | ||||
|         if (priority === '') { | ||||
|         if (value === '') { | ||||
|           return; | ||||
|         } | ||||
|         data.priority = parseInt(priority); | ||||
|         data.priority = parseInt(value); | ||||
|         res = await API.put('/api/channel/', data); | ||||
|         break; | ||||
|       case 'weight': | ||||
|         if (value === '') { | ||||
|           return; | ||||
|         } | ||||
|         data.weight = parseInt(value); | ||||
|         if (data.weight < 0) { | ||||
|           data.weight = 0; | ||||
|         } | ||||
|         res = await API.put('/api/channel/', data); | ||||
|         break; | ||||
|     } | ||||
| @@ -142,9 +153,23 @@ const ChannelsTable = () => { | ||||
|         return <Label basic color='green'>已启用</Label>; | ||||
|       case 2: | ||||
|         return ( | ||||
|           <Label basic color='red'> | ||||
|             已禁用 | ||||
|           </Label> | ||||
|           <Popup | ||||
|             trigger={<Label basic color='red'> | ||||
|               已禁用 | ||||
|             </Label>} | ||||
|             content='本渠道被手动禁用' | ||||
|             basic | ||||
|           /> | ||||
|         ); | ||||
|       case 3: | ||||
|         return ( | ||||
|           <Popup | ||||
|             trigger={<Label basic color='yellow'> | ||||
|               已禁用 | ||||
|             </Label>} | ||||
|             content='本渠道被程序自动禁用' | ||||
|             basic | ||||
|           /> | ||||
|         ); | ||||
|       default: | ||||
|         return ( | ||||
| @@ -202,7 +227,6 @@ const ChannelsTable = () => { | ||||
|       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); | ||||
|     } else { | ||||
|       showError(message); | ||||
|       showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") | ||||
|     } | ||||
|   }; | ||||
|  | ||||
| @@ -210,7 +234,18 @@ const ChannelsTable = () => { | ||||
|     const res = await API.get(`/api/channel/test`); | ||||
|     const { success, message } = res.data; | ||||
|     if (success) { | ||||
|       showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。'); | ||||
|       showInfo('已成功开始测试所有通道,请刷新页面查看结果。'); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const deleteAllDisabledChannels = async () => { | ||||
|     const res = await API.delete(`/api/channel/disabled`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); | ||||
|       await refresh(); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -251,17 +286,15 @@ const ChannelsTable = () => { | ||||
|     if (channels.length === 0) return; | ||||
|     setLoading(true); | ||||
|     let sortedChannels = [...channels]; | ||||
|     if (typeof sortedChannels[0][key] === 'string') { | ||||
|       sortedChannels.sort((a, b) => { | ||||
|     sortedChannels.sort((a, b) => { | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       }); | ||||
|     } else { | ||||
|       sortedChannels.sort((a, b) => { | ||||
|         if (a[key] === b[key]) return 0; | ||||
|         if (a[key] > b[key]) return -1; | ||||
|         if (a[key] < b[key]) return 1; | ||||
|       }); | ||||
|     } | ||||
|       } | ||||
|     }); | ||||
|     if (sortedChannels[0].id === channels[0].id) { | ||||
|       sortedChannels.reverse(); | ||||
|     } | ||||
| @@ -269,6 +302,7 @@ const ChannelsTable = () => { | ||||
|     setLoading(false); | ||||
|   }; | ||||
|  | ||||
|  | ||||
|   return ( | ||||
|     <> | ||||
|       <Form onSubmit={searchChannels}> | ||||
| @@ -282,7 +316,19 @@ const ChannelsTable = () => { | ||||
|           onChange={handleKeywordChange} | ||||
|         /> | ||||
|       </Form> | ||||
|       { | ||||
|         showPrompt && ( | ||||
|           <Message onDismiss={() => { | ||||
|             setShowPrompt(false); | ||||
|             setPromptShown("channel-test"); | ||||
|           }}> | ||||
|             当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo | ||||
|             模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 | ||||
|  | ||||
|             另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 | ||||
|           </Message> | ||||
|         ) | ||||
|       } | ||||
|       <Table basic compact size='small'> | ||||
|         <Table.Header> | ||||
|           <Table.Row> | ||||
| @@ -343,10 +389,10 @@ const ChannelsTable = () => { | ||||
|               余额 | ||||
|             </Table.HeaderCell> | ||||
|             <Table.HeaderCell | ||||
|                 style={{ cursor: 'pointer' }} | ||||
|                 onClick={() => { | ||||
|                   sortChannel('priority'); | ||||
|                 }} | ||||
|               style={{ cursor: 'pointer' }} | ||||
|               onClick={() => { | ||||
|                 sortChannel('priority'); | ||||
|               }} | ||||
|             > | ||||
|               优先级 | ||||
|             </Table.HeaderCell> | ||||
| @@ -390,18 +436,18 @@ const ChannelsTable = () => { | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <Popup | ||||
|                         trigger={<Input type="number"  defaultValue={channel.priority} onBlur={(event) => { | ||||
|                           manageChannel( | ||||
|                               channel.id, | ||||
|                               'priority', | ||||
|                               idx, | ||||
|                               event.target.value, | ||||
|                           ); | ||||
|                         }}> | ||||
|                           <input style={{maxWidth:'60px'}} /> | ||||
|                         </Input>} | ||||
|                         content='渠道选择优先级,越高越优先' | ||||
|                         basic | ||||
|                       trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => { | ||||
|                         manageChannel( | ||||
|                           channel.id, | ||||
|                           'priority', | ||||
|                           idx, | ||||
|                           event.target.value | ||||
|                         ); | ||||
|                       }}> | ||||
|                         <input style={{ maxWidth: '60px' }} /> | ||||
|                       </Input>} | ||||
|                       content='渠道选择优先级,越高越优先' | ||||
|                       basic | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
| @@ -481,6 +527,20 @@ const ChannelsTable = () => { | ||||
|               </Button> | ||||
|               <Button size='small' onClick={updateAllChannelsBalance} | ||||
|                       loading={loading || updatingBalance}>更新所有已启用通道余额</Button> | ||||
|               <Popup | ||||
|                 trigger={ | ||||
|                   <Button size='small' loading={loading}> | ||||
|                     删除禁用渠道 | ||||
|                   </Button> | ||||
|                 } | ||||
|                 on='click' | ||||
|                 flowing | ||||
|                 hoverable | ||||
|               > | ||||
|                 <Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}> | ||||
|                   确认删除 | ||||
|                 </Button> | ||||
|               </Popup> | ||||
|               <Pagination | ||||
|                 floated='right' | ||||
|                 activePage={activePage} | ||||
|   | ||||
| @@ -2,8 +2,8 @@ import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { UserContext } from '../context/User'; | ||||
| import { API, getLogo, showError, showSuccess } from '../helpers'; | ||||
| import { getOAuthState, onGitHubOAuthClicked } from './utils'; | ||||
| import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; | ||||
| import { onGitHubOAuthClicked } from './utils'; | ||||
|  | ||||
| const LoginForm = () => { | ||||
|   const [inputs, setInputs] = useState({ | ||||
| @@ -68,8 +68,14 @@ const LoginForm = () => { | ||||
|       if (success) { | ||||
|         userDispatch({ type: 'login', payload: data }); | ||||
|         localStorage.setItem('user', JSON.stringify(data)); | ||||
|         navigate('/'); | ||||
|         showSuccess('登录成功!'); | ||||
|         if (username === 'root' && password === '123456') { | ||||
|           navigate('/user/edit'); | ||||
|           showSuccess('登录成功!'); | ||||
|           showWarning('请立刻修改默认密码!'); | ||||
|         } else { | ||||
|           navigate('/token'); | ||||
|           showSuccess('登录成功!'); | ||||
|         } | ||||
|       } else { | ||||
|         showError(message); | ||||
|       } | ||||
| @@ -126,7 +132,7 @@ const LoginForm = () => { | ||||
|                 circular | ||||
|                 color='black' | ||||
|                 icon='github' | ||||
|                 onClick={()=>onGitHubOAuthClicked(status.github_client_id)} | ||||
|                 onClick={() => onGitHubOAuthClicked(status.github_client_id)} | ||||
|               /> | ||||
|             ) : ( | ||||
|               <></> | ||||
|   | ||||
| @@ -16,6 +16,7 @@ const OperationSetting = () => { | ||||
|     ChatLink: '', | ||||
|     QuotaPerUnit: 0, | ||||
|     AutomaticDisableChannelEnabled: '', | ||||
|     AutomaticEnableChannelEnabled: '', | ||||
|     ChannelDisableThreshold: 0, | ||||
|     LogConsumeEnabled: '', | ||||
|     DisplayInCurrencyEnabled: '', | ||||
| @@ -269,6 +270,12 @@ const OperationSetting = () => { | ||||
|               name='AutomaticDisableChannelEnabled' | ||||
|               onChange={handleInputChange} | ||||
|             /> | ||||
|             <Form.Checkbox | ||||
|                 checked={inputs.AutomaticEnableChannelEnabled === 'true'} | ||||
|                 label='成功时自动启用通道' | ||||
|                 name='AutomaticEnableChannelEnabled' | ||||
|                 onChange={handleInputChange} | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={() => { | ||||
|             submitConfig('monitor').then(); | ||||
|   | ||||
| @@ -130,7 +130,13 @@ const RedemptionsTable = () => { | ||||
|     setLoading(true); | ||||
|     let sortedRedemptions = [...redemptions]; | ||||
|     sortedRedemptions.sort((a, b) => { | ||||
|       return ('' + a[key]).localeCompare(b[key]); | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       } | ||||
|     }); | ||||
|     if (sortedRedemptions[0].id === redemptions[0].id) { | ||||
|       sortedRedemptions.reverse(); | ||||
|   | ||||
| @@ -138,7 +138,7 @@ const TokensTable = () => { | ||||
|     let defaultUrl; | ||||
|    | ||||
|     if (chatLink) { | ||||
|       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; | ||||
|       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||
|     } else { | ||||
|       defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||
|     } | ||||
| @@ -228,7 +228,13 @@ const TokensTable = () => { | ||||
|     setLoading(true); | ||||
|     let sortedTokens = [...tokens]; | ||||
|     sortedTokens.sort((a, b) => { | ||||
|       return ('' + a[key]).localeCompare(b[key]); | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       } | ||||
|     }); | ||||
|     if (sortedTokens[0].id === tokens[0].id) { | ||||
|       sortedTokens.reverse(); | ||||
|   | ||||
| @@ -133,7 +133,13 @@ const UsersTable = () => { | ||||
|     setLoading(true); | ||||
|     let sortedUsers = [...users]; | ||||
|     sortedUsers.sort((a, b) => { | ||||
|       return ('' + a[key]).localeCompare(b[key]); | ||||
|       if (!isNaN(a[key])) { | ||||
|         // If the value is numeric, subtract to sort | ||||
|         return a[key] - b[key]; | ||||
|       } else { | ||||
|         // If the value is not numeric, sort as strings | ||||
|         return ('' + a[key]).localeCompare(b[key]); | ||||
|       } | ||||
|     }); | ||||
|     if (sortedUsers[0].id === users[0].id) { | ||||
|       sortedUsers.reverse(); | ||||
|   | ||||
| @@ -8,6 +8,7 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||
|   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, | ||||
|   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||
|   | ||||
| @@ -186,4 +186,14 @@ export const verifyJSON = (str) => { | ||||
|     return false; | ||||
|   } | ||||
|   return true; | ||||
| }; | ||||
| }; | ||||
|  | ||||
| export function shouldShowPrompt(id) { | ||||
|   let prompt = localStorage.getItem(`prompt-${id}`); | ||||
|   return !prompt; | ||||
|  | ||||
| } | ||||
|  | ||||
| export function setPromptShown(id) { | ||||
|   localStorage.setItem(`prompt-${id}`, 'true'); | ||||
| } | ||||
| @@ -19,6 +19,8 @@ function type2secretPrompt(type) { | ||||
|       return '按照如下格式输入:APPID|APISecret|APIKey'; | ||||
|     case 22: | ||||
|       return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; | ||||
|     case 23: | ||||
|       return '按照如下格式输入:AppId|SecretId|SecretKey'; | ||||
|     default: | ||||
|       return '请输入渠道对应的鉴权密钥'; | ||||
|   } | ||||
| @@ -58,25 +60,28 @@ const EditChannel = () => { | ||||
|       let localModels = []; | ||||
|       switch (value) { | ||||
|         case 14: | ||||
|           localModels = ['claude-instant-1', 'claude-2']; | ||||
|           localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; | ||||
|           break; | ||||
|         case 11: | ||||
|           localModels = ['PaLM-2']; | ||||
|           break; | ||||
|         case 15: | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; | ||||
|           break; | ||||
|         case 17: | ||||
|           localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1']; | ||||
|           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; | ||||
|           break; | ||||
|         case 16: | ||||
|           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           break; | ||||
|         case 18: | ||||
|           localModels = ['SparkDesk']; | ||||
|           break; | ||||
|         case 19: | ||||
|           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4']; | ||||
|           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; | ||||
|           break; | ||||
|         case 23: | ||||
|           localModels = ['hunyuan']; | ||||
|           break; | ||||
|       } | ||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||
| @@ -174,7 +179,7 @@ const EditChannel = () => { | ||||
|       return; | ||||
|     } | ||||
|     let localInputs = inputs; | ||||
|     if (localInputs.base_url.endsWith('/')) { | ||||
|     if (localInputs.base_url && localInputs.base_url.endsWith('/')) { | ||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||
|     } | ||||
|     if (localInputs.type === 3 && localInputs.other === '') { | ||||
| @@ -183,9 +188,6 @@ const EditChannel = () => { | ||||
|     if (localInputs.type === 18 && localInputs.other === '') { | ||||
|       localInputs.other = 'v2.1'; | ||||
|     } | ||||
|     if (localInputs.model_mapping === '') { | ||||
|       localInputs.model_mapping = '{}'; | ||||
|     } | ||||
|     let res; | ||||
|     localInputs.models = localInputs.models.join(','); | ||||
|     localInputs.group = localInputs.groups.join(','); | ||||
|   | ||||
| @@ -1,19 +1,12 @@ | ||||
| import React from 'react'; | ||||
| import { Segment, Header } from 'semantic-ui-react'; | ||||
| import { Message } from 'semantic-ui-react'; | ||||
|  | ||||
| const NotFound = () => ( | ||||
|   <> | ||||
|     <Header | ||||
|       block | ||||
|       as="h4" | ||||
|       content="404" | ||||
|       attached="top" | ||||
|       icon="info" | ||||
|       className="small-icon" | ||||
|     /> | ||||
|     <Segment attached="bottom"> | ||||
|       未找到所请求的页面 | ||||
|     </Segment> | ||||
|     <Message negative> | ||||
|       <Message.Header>页面不存在</Message.Header> | ||||
|       <p>请检查你的浏览器地址是否正确</p> | ||||
|     </Message> | ||||
|   </> | ||||
| ); | ||||
|  | ||||
|   | ||||
| @@ -102,7 +102,7 @@ const EditUser = () => { | ||||
|               label='密码' | ||||
|               name='password' | ||||
|               type={'password'} | ||||
|               placeholder={'请输入新的密码'} | ||||
|               placeholder={'请输入新的密码,最短 8 位'} | ||||
|               onChange={handleInputChange} | ||||
|               value={password} | ||||
|               autoComplete='new-password' | ||||
|   | ||||
		Reference in New Issue
	
	Block a user