mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-01 06:13:43 +08:00 
			
		
		
		
	Compare commits
	
		
			31 Commits
		
	
	
		
			v0.5.6-alp
			...
			v0.5.7
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | aec343dc38 | ||
|  | 89d458b9cf | ||
|  | 63fafba112 | ||
|  | a398f35968 | ||
|  | 57aa637c77 | ||
|  | 3b483639a4 | ||
|  | 22980b4c44 | ||
|  | 64cdb7eafb | ||
|  | 824444244b | ||
|  | fbe9985f57 | ||
|  | a27a5bcc06 | ||
|  | e28d4b1741 | ||
|  | f073592d39 | ||
|  | fa41ca9805 | ||
|  | e338de45b6 | ||
|  | 114587b46f | ||
|  | b4b4acc288 | ||
|  | d663de3e3a | ||
|  | a85ecace2e | ||
|  | fbdea91ea1 | ||
|  | 8d34b7a77e | ||
|  | cbd62011b8 | ||
|  | 4701897e2e | ||
|  | 0f6c132a80 | ||
|  | 3cac45dc85 | ||
|  | 47c08c72ce | ||
|  | 53b2cace0b | ||
|  | f0fc991b44 | ||
|  | 594f06e7b0 | ||
|  | 197d1d7a9d | ||
|  | f9b748c2ca | 
							
								
								
									
										71
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										71
									
								
								README.md
									
									
									
									
									
								
							| @@ -59,6 +59,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| > **Warning** | > **Warning** | ||||||
| > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | ||||||
|  |  | ||||||
|  | > **Warning** | ||||||
|  | > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | ||||||
|  |  | ||||||
| ## 功能 | ## 功能 | ||||||
| 1. 支持多种大模型: | 1. 支持多种大模型: | ||||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) |    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
| @@ -69,6 +72,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) |    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) |    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||||
|    + [x] [360 智脑](https://ai.360.cn) |    + [x] [360 智脑](https://ai.360.cn) | ||||||
|  |    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||||
| 2. 支持配置镜像以及众多第三方代理服务: | 2. 支持配置镜像以及众多第三方代理服务: | ||||||
|    + [x] [OpenAI-SB](https://openai-sb.com) |    + [x] [OpenAI-SB](https://openai-sb.com) | ||||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) |    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||||
| @@ -91,23 +95,30 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| 15. 支持模型映射,重定向用户的请求模型。 | 15. 支持模型映射,重定向用户的请求模型。 | ||||||
| 16. 支持失败自动重试。 | 16. 支持失败自动重试。 | ||||||
| 17. 支持绘图接口。 | 17. 支持绘图接口。 | ||||||
| 18. 支持丰富的**自定义**设置, | 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 | ||||||
|  | 19. 支持丰富的**自定义**设置, | ||||||
|     1. 支持自定义系统名称,logo 以及页脚。 |     1. 支持自定义系统名称,logo 以及页脚。 | ||||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 |     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||||
| 19. 支持通过系统访问令牌访问管理 API。 | 20. 支持通过系统访问令牌访问管理 API。 | ||||||
| 20. 支持 Cloudflare Turnstile 用户校验。 | 21. 支持 Cloudflare Turnstile 用户校验。 | ||||||
| 21. 支持用户管理,支持**多种用户登录注册方式**: | 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 |     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
|  |  | ||||||
| ## 部署 | ## 部署 | ||||||
| ### 基于 Docker 进行部署 | ### 基于 Docker 进行部署 | ||||||
| 部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` | ```shell | ||||||
|  | # 使用 SQLite 的部署命令: | ||||||
|  | docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||||
|  | # 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。 | ||||||
|  | # 例如: | ||||||
|  | docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||||
|  | ``` | ||||||
|  |  | ||||||
| 其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 | 其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 | ||||||
|  |  | ||||||
| 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | 数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | ||||||
|  |  | ||||||
| 如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 | 如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 | ||||||
|  |  | ||||||
| @@ -236,7 +247,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
| <summary><strong>部署到 Zeabur</strong></summary> | <summary><strong>部署到 Zeabur</strong></summary> | ||||||
| <div> | <div> | ||||||
|  |  | ||||||
| > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。 | > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 | ||||||
|  |  | ||||||
| 1. 首先 fork 一份代码。 | 1. 首先 fork 一份代码。 | ||||||
| 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | ||||||
| @@ -251,6 +262,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
| </div> | </div> | ||||||
| </details> | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><strong>部署到 Render</strong></summary> | ||||||
|  | <div> | ||||||
|  |  | ||||||
|  | > Render 提供免费额度,绑卡后可以进一步提升额度 | ||||||
|  |  | ||||||
|  | Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com | ||||||
|  |  | ||||||
|  | </div> | ||||||
|  | </details> | ||||||
|  |  | ||||||
| ## 配置 | ## 配置 | ||||||
| 系统本身开箱即用。 | 系统本身开箱即用。 | ||||||
|  |  | ||||||
| @@ -278,10 +300,11 @@ OPENAI_API_BASE="https://<HOST>:<PORT>/v1" | |||||||
| ```mermaid | ```mermaid | ||||||
| graph LR | graph LR | ||||||
|     A(用户) |     A(用户) | ||||||
|     A --->|请求| B(One API) |     A --->|使用 One API 分发的 key 进行请求| B(One API) | ||||||
|     B -->|中继请求| C(OpenAI) |     B -->|中继请求| C(OpenAI) | ||||||
|     B -->|中继请求| D(Azure) |     B -->|中继请求| D(Azure) | ||||||
|     B -->|中继请求| E(其他下游渠道) |     B -->|中继请求| E(其他 OpenAI API 格式下游渠道) | ||||||
|  |     B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 | 可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 | ||||||
| @@ -309,24 +332,30 @@ graph LR | |||||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 |      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` |    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。 | 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|  |    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||||
|  | 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||||
|    + 例子:`SYNC_FREQUENCY=60` |    + 例子:`SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||||
|    + 例子:`NODE_TYPE=slave` |    + 例子:`NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` |    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` |    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||||
|    + 例子:`POLLING_INTERVAL=5` |     + 例子:`POLLING_INTERVAL=5` | ||||||
| 10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|     + 例子:`BATCH_UPDATE_ENABLED=true` |     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||||
|     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 |     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||||
| 11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||||
|     + 例子:`BATCH_UPDATE_INTERVAL=5` |     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||||
| 12. 请求频率限制: | 13. 请求频率限制: | ||||||
|     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 |     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||||
|     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 |     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||||
|  | 14. 编码器缓存设置: | ||||||
|  |     + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 | ||||||
|  |     + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 | ||||||
|  | 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
| @@ -366,6 +395,12 @@ https://openai.justsong.cn | |||||||
|    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 |    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||||
|    + 上游通道 429 了。 |    + 上游通道 429 了。 | ||||||
|  | 7. 升级之后我的数据会丢失吗? | ||||||
|  |    + 如果使用 MySQL,不会。 | ||||||
|  |    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||||
|  | 8. 升级之前数据库需要做变更吗? | ||||||
|  |    + 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||||
|  |    + 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||||
|  |  | ||||||
| ## 相关项目 | ## 相关项目 | ||||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||||
|   | |||||||
| @@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens | |||||||
| var DisplayInCurrencyEnabled = true | var DisplayInCurrencyEnabled = true | ||||||
| var DisplayTokenStatEnabled = true | var DisplayTokenStatEnabled = true | ||||||
|  |  | ||||||
| var UsingSQLite = false |  | ||||||
|  |  | ||||||
| // Any options with "Secret", "Token" in its key won't be return by GetOptions | // Any options with "Secret", "Token" in its key won't be return by GetOptions | ||||||
|  |  | ||||||
| var SessionSecret = uuid.New().String() | var SessionSecret = uuid.New().String() | ||||||
| var SQLitePath = "one-api.db" |  | ||||||
|  |  | ||||||
| var OptionMap map[string]string | var OptionMap map[string]string | ||||||
| var OptionMapRWMutex sync.RWMutex | var OptionMapRWMutex sync.RWMutex | ||||||
| @@ -56,6 +53,7 @@ var EmailDomainWhitelist = []string{ | |||||||
| } | } | ||||||
|  |  | ||||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||||
|  | var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||||
|  |  | ||||||
| var LogConsumeEnabled = true | var LogConsumeEnabled = true | ||||||
|  |  | ||||||
| @@ -92,11 +90,13 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | |||||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | var RequestInterval = time.Duration(requestInterval) * time.Second | ||||||
|  |  | ||||||
| var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY | var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||||
|  |  | ||||||
| var BatchUpdateEnabled = false | var BatchUpdateEnabled = false | ||||||
| var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||||
|  |  | ||||||
|  | var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | 	RequestIdKey = "X-Oneapi-Request-Id" | ||||||
| ) | ) | ||||||
| @@ -155,9 +155,10 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ChannelStatusUnknown  = 0 | 	ChannelStatusUnknown          = 0 | ||||||
| 	ChannelStatusEnabled  = 1 // don't use 0, 0 is the default value! | 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! | ||||||
| 	ChannelStatusDisabled = 2 // also don't use 0 | 	ChannelStatusManuallyDisabled = 2 // also don't use 0 | ||||||
|  | 	ChannelStatusAutoDisabled     = 3 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -184,30 +185,32 @@ const ( | |||||||
| 	ChannelTypeOpenRouter     = 20 | 	ChannelTypeOpenRouter     = 20 | ||||||
| 	ChannelTypeAIProxyLibrary = 21 | 	ChannelTypeAIProxyLibrary = 21 | ||||||
| 	ChannelTypeFastGPT        = 22 | 	ChannelTypeFastGPT        = 22 | ||||||
|  | 	ChannelTypeTencent        = 23 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| 	"",                                // 0 | 	"",                                  // 0 | ||||||
| 	"https://api.openai.com",          // 1 | 	"https://api.openai.com",            // 1 | ||||||
| 	"https://oa.api2d.net",            // 2 | 	"https://oa.api2d.net",              // 2 | ||||||
| 	"",                                // 3 | 	"",                                  // 3 | ||||||
| 	"https://api.closeai-proxy.xyz",   // 4 | 	"https://api.closeai-proxy.xyz",     // 4 | ||||||
| 	"https://api.openai-sb.com",       // 5 | 	"https://api.openai-sb.com",         // 5 | ||||||
| 	"https://api.openaimax.com",       // 6 | 	"https://api.openaimax.com",         // 6 | ||||||
| 	"https://api.ohmygpt.com",         // 7 | 	"https://api.ohmygpt.com",           // 7 | ||||||
| 	"",                                // 8 | 	"",                                  // 8 | ||||||
| 	"https://api.caipacity.com",       // 9 | 	"https://api.caipacity.com",         // 9 | ||||||
| 	"https://api.aiproxy.io",          // 10 | 	"https://api.aiproxy.io",            // 10 | ||||||
| 	"",                                // 11 | 	"",                                  // 11 | ||||||
| 	"https://api.api2gpt.com",         // 12 | 	"https://api.api2gpt.com",           // 12 | ||||||
| 	"https://api.aigc2d.com",          // 13 | 	"https://api.aigc2d.com",            // 13 | ||||||
| 	"https://api.anthropic.com",       // 14 | 	"https://api.anthropic.com",         // 14 | ||||||
| 	"https://aip.baidubce.com",        // 15 | 	"https://aip.baidubce.com",          // 15 | ||||||
| 	"https://open.bigmodel.cn",        // 16 | 	"https://open.bigmodel.cn",          // 16 | ||||||
| 	"https://dashscope.aliyuncs.com",  // 17 | 	"https://dashscope.aliyuncs.com",    // 17 | ||||||
| 	"",                                // 18 | 	"",                                  // 18 | ||||||
| 	"https://ai.360.cn",               // 19 | 	"https://ai.360.cn",                 // 19 | ||||||
| 	"https://openrouter.ai/api",       // 20 | 	"https://openrouter.ai/api",         // 20 | ||||||
| 	"https://api.aiproxy.io",          // 21 | 	"https://api.aiproxy.io",            // 21 | ||||||
| 	"https://fastgpt.run/api/openapi", // 22 | 	"https://fastgpt.run/api/openapi",   // 22 | ||||||
|  | 	"https://hunyuan.cloud.tencent.com", //23 | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | package common | ||||||
|  |  | ||||||
|  | var UsingSQLite = false | ||||||
|  | var UsingPostgreSQL = false | ||||||
|  |  | ||||||
|  | var SQLitePath = "one-api.db" | ||||||
| @@ -46,6 +46,7 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"claude-2":                  5.51,   // $11.02 / 1M tokens | 	"claude-2":                  5.51,   // $11.02 / 1M tokens | ||||||
| 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens | 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens | ||||||
| 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens | 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens | ||||||
|  | 	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens | ||||||
| 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens | ||||||
| 	"PaLM-2":                    1, | 	"PaLM-2":                    1, | ||||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||||
| @@ -59,7 +60,7 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens | 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { | func ModelRatio2JSONString() string { | ||||||
|   | |||||||
| @@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int { | |||||||
| func MessageWithRequestId(message string, id string) string { | func MessageWithRequestId(message string, id string) string { | ||||||
| 	return fmt.Sprintf("%s (request id: %s)", message, id) | 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func String2Int(str string) int { | ||||||
|  | 	num, err := strconv.Atoi(str) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	return num | ||||||
|  | } | ||||||
|   | |||||||
| @@ -5,13 +5,14 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | ||||||
| @@ -141,7 +142,7 @@ func disableChannel(channelId int, channelName string, reason string) { | |||||||
| 	if common.RootUserEmail == "" { | 	if common.RootUserEmail == "" { | ||||||
| 		common.RootUserEmail = model.GetRootUserEmail() | 		common.RootUserEmail = model.GetRootUserEmail() | ||||||
| 	} | 	} | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||||
| 	err := common.SendEmail(subject, common.RootUserEmail, content) | 	err := common.SendEmail(subject, common.RootUserEmail, content) | ||||||
|   | |||||||
| @@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func DeleteDisabledChannel(c *gin.Context) { | ||||||
|  | 	rows, err := model.DeleteDisabledChannel() | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    rows, | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| func UpdateChannel(c *gin.Context) { | func UpdateChannel(c *gin.Context) { | ||||||
| 	channel := model.Channel{} | 	channel := model.Channel{} | ||||||
| 	err := c.ShouldBindJSON(&channel) | 	err := c.ShouldBindJSON(&channel) | ||||||
|   | |||||||
| @@ -306,6 +306,15 @@ func init() { | |||||||
| 			Root:       "ERNIE-Bot-turbo", | 			Root:       "ERNIE-Bot-turbo", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "ERNIE-Bot-4", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "baidu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "ERNIE-Bot-4", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "Embedding-V1", | 			Id:         "Embedding-V1", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| @@ -424,12 +433,12 @@ func init() { | |||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "360GPT_S2_V9.4", | 			Id:         "hunyuan", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| 			Created:    1677649963, | 			Created:    1677649963, | ||||||
| 			OwnedBy:    "360", | 			OwnedBy:    "tencent", | ||||||
| 			Permission: permission, | 			Permission: permission, | ||||||
| 			Root:       "360GPT_S2_V9.4", | 			Root:       "hunyuan", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 		if option.Value == "true" && common.GitHubClientId == "" { | 		if option.Value == "true" && common.GitHubClientId == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!", | 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | ||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -4,13 +4,13 @@ import ( | |||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||||
| @@ -31,6 +31,9 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|  | 	if userQuota-preConsumedQuota < 0 { | ||||||
|  | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
|  | 	} | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
| @@ -62,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
|  |  | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
|  |  | ||||||
| 	if c.GetString("base_url") != "" { | 	if c.GetString("base_url") != "" { | ||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 	requestBody := c.Request.Body | 	requestBody := c.Request.Body | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
|   | |||||||
| @@ -6,12 +6,11 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||||
| @@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 			isModelMapped = true | 			isModelMapped = true | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
|  |  | ||||||
| 	if c.GetString("base_url") != "" { | 	if c.GetString("base_url") != "" { | ||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
|  | 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) |  | ||||||
|  |  | ||||||
| 	var requestBody io.Reader | 	var requestBody io.Reader | ||||||
| 	if isModelMapped { | 	if isModelMapped { | ||||||
| 		jsonStr, err := json.Marshal(imageRequest) | 		jsonStr, err := json.Marshal(imageRequest) | ||||||
| @@ -99,7 +94,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||||
|  |  | ||||||
| 	if consumeQuota && userQuota-quota < 0 { | 	if consumeQuota && userQuota-quota < 0 { | ||||||
| 		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
|   | |||||||
							
								
								
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,287 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"crypto/hmac" | ||||||
|  | 	"crypto/sha1" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"sort" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://cloud.tencent.com/document/product/1729/97732 | ||||||
|  |  | ||||||
|  | type TencentMessage struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentChatRequest struct { | ||||||
|  | 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID | ||||||
|  | 	SecretId string `json:"secret_id"` // 官网 SecretId | ||||||
|  | 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 | ||||||
|  | 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 | ||||||
|  | 	Timestamp int64 `json:"timestamp"` | ||||||
|  | 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, | ||||||
|  | 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 | ||||||
|  | 	Expired int64  `json:"expired"` | ||||||
|  | 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 | ||||||
|  | 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 | ||||||
|  | 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 | ||||||
|  | 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p | ||||||
|  | 	Temperature float64 `json:"temperature"` | ||||||
|  | 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 | ||||||
|  | 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 | ||||||
|  | 	// 建议该参数和 temperature 只设置1个,不要同时更改 | ||||||
|  | 	TopP float64 `json:"top_p"` | ||||||
|  | 	// Stream 0:同步,1:流式 (默认,协议:SSE) | ||||||
|  | 	// 同步请求超时:60s,如果内容较长建议使用流式 | ||||||
|  | 	Stream int `json:"stream"` | ||||||
|  | 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 | ||||||
|  | 	// 输入 content 总数最大支持 3000 token。 | ||||||
|  | 	Messages []TencentMessage `json:"messages"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentError struct { | ||||||
|  | 	Code    int    `json:"code"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentUsage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | 	TotalTokens  int `json:"total_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentResponseChoices struct { | ||||||
|  | 	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||||
|  | 	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||||
|  | 	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentChatResponse struct { | ||||||
|  | 	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 | ||||||
|  | 	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串 | ||||||
|  | 	Id      string                   `json:"id,omitempty"`      // 会话 id | ||||||
|  | 	Usage   Usage                    `json:"usage,omitempty"`   // token 数量 | ||||||
|  | 	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||||
|  | 	Note    string                   `json:"note,omitempty"`    // 注释 | ||||||
|  | 	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | ||||||
|  | 	messages := make([]TencentMessage, 0, len(request.Messages)) | ||||||
|  | 	for i := 0; i < len(request.Messages); i++ { | ||||||
|  | 		message := request.Messages[i] | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, TencentMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, TencentMessage{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		messages = append(messages, TencentMessage{ | ||||||
|  | 			Content: message.Content, | ||||||
|  | 			Role:    message.Role, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	stream := 0 | ||||||
|  | 	if request.Stream { | ||||||
|  | 		stream = 1 | ||||||
|  | 	} | ||||||
|  | 	return &TencentChatRequest{ | ||||||
|  | 		Timestamp:   common.GetTimestamp(), | ||||||
|  | 		Expired:     common.GetTimestamp() + 24*60*60, | ||||||
|  | 		QueryID:     common.GetUUID(), | ||||||
|  | 		Temperature: request.Temperature, | ||||||
|  | 		TopP:        request.TopP, | ||||||
|  | 		Stream:      stream, | ||||||
|  | 		Messages:    messages, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Usage:   response.Usage, | ||||||
|  | 	} | ||||||
|  | 	if len(response.Choices) > 0 { | ||||||
|  | 		choice := OpenAITextResponseChoice{ | ||||||
|  | 			Index: 0, | ||||||
|  | 			Message: Message{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: response.Choices[0].Messages.Content, | ||||||
|  | 			}, | ||||||
|  | 			FinishReason: response.Choices[0].FinishReason, | ||||||
|  | 		} | ||||||
|  | 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "tencent-hunyuan", | ||||||
|  | 	} | ||||||
|  | 	if len(TencentResponse.Choices) > 0 { | ||||||
|  | 		var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content | ||||||
|  | 		if TencentResponse.Choices[0].FinishReason == "stop" { | ||||||
|  | 			choice.FinishReason = &stopFinishReason | ||||||
|  | 		} | ||||||
|  | 		response.Choices = append(response.Choices, choice) | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||||
|  | 	var responseText string | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 5 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:5] != "data:" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[5:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var TencentResponse TencentChatResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &TencentResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			response := streamResponseTencent2OpenAI(&TencentResponse) | ||||||
|  | 			if len(response.Choices) != 0 { | ||||||
|  | 				responseText += response.Choices[0].Delta.Content | ||||||
|  | 			} | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
|  | 	} | ||||||
|  | 	return nil, responseText | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var TencentResponse TencentChatResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &TencentResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if TencentResponse.Error.Code != 0 { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: TencentResponse.Error.Message, | ||||||
|  | 				Code:    TencentResponse.Error.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseTencent2OpenAI(&TencentResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { | ||||||
|  | 	parts := strings.Split(config, "|") | ||||||
|  | 	if len(parts) != 3 { | ||||||
|  | 		err = errors.New("invalid tencent config") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	appId, err = strconv.ParseInt(parts[0], 10, 64) | ||||||
|  | 	secretId = parts[1] | ||||||
|  | 	secretKey = parts[2] | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getTencentSign(req TencentChatRequest, secretKey string) string { | ||||||
|  | 	params := make([]string, 0) | ||||||
|  | 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | ||||||
|  | 	params = append(params, "secret_id="+req.SecretId) | ||||||
|  | 	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) | ||||||
|  | 	params = append(params, "query_id="+req.QueryID) | ||||||
|  | 	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) | ||||||
|  | 	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) | ||||||
|  | 	params = append(params, "stream="+strconv.Itoa(req.Stream)) | ||||||
|  | 	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) | ||||||
|  |  | ||||||
|  | 	var messageStr string | ||||||
|  | 	for _, msg := range req.Messages { | ||||||
|  | 		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) | ||||||
|  | 	} | ||||||
|  | 	messageStr = strings.TrimSuffix(messageStr, ",") | ||||||
|  | 	params = append(params, "messages=["+messageStr+"]") | ||||||
|  |  | ||||||
|  | 	sort.Sort(sort.StringSlice(params)) | ||||||
|  | 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") | ||||||
|  | 	mac := hmac.New(sha1.New, []byte(secretKey)) | ||||||
|  | 	signURL := url | ||||||
|  | 	mac.Write([]byte(signURL)) | ||||||
|  | 	sign := mac.Sum([]byte(nil)) | ||||||
|  | 	return base64.StdEncoding.EncodeToString(sign) | ||||||
|  | } | ||||||
| @@ -6,13 +6,14 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -24,13 +25,21 @@ const ( | |||||||
| 	APITypeAli | 	APITypeAli | ||||||
| 	APITypeXunfei | 	APITypeXunfei | ||||||
| 	APITypeAIProxyLibrary | 	APITypeAIProxyLibrary | ||||||
|  | 	APITypeTencent | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var httpClient *http.Client | var httpClient *http.Client | ||||||
| var impatientHTTPClient *http.Client | var impatientHTTPClient *http.Client | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	httpClient = &http.Client{} | 	if common.RelayTimeout == 0 { | ||||||
|  | 		httpClient = &http.Client{} | ||||||
|  | 	} else { | ||||||
|  | 		httpClient = &http.Client{ | ||||||
|  | 			Timeout: time.Duration(common.RelayTimeout) * time.Second, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	impatientHTTPClient = &http.Client{ | 	impatientHTTPClient = &http.Client{ | ||||||
| 		Timeout: 5 * time.Second, | 		Timeout: 5 * time.Second, | ||||||
| 	} | 	} | ||||||
| @@ -109,13 +118,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		apiType = APITypeXunfei | 		apiType = APITypeXunfei | ||||||
| 	case common.ChannelTypeAIProxyLibrary: | 	case common.ChannelTypeAIProxyLibrary: | ||||||
| 		apiType = APITypeAIProxyLibrary | 		apiType = APITypeAIProxyLibrary | ||||||
|  | 	case common.ChannelTypeTencent: | ||||||
|  | 		apiType = APITypeTencent | ||||||
| 	} | 	} | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
| 	if c.GetString("base_url") != "" { | 	if c.GetString("base_url") != "" { | ||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) | ||||||
| 	switch apiType { | 	switch apiType { | ||||||
| 	case APITypeOpenAI: | 	case APITypeOpenAI: | ||||||
| 		if channelType == common.ChannelTypeAzure { | 		if channelType == common.ChannelTypeAzure { | ||||||
| @@ -148,6 +159,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||||
| 		case "ERNIE-Bot-turbo": | 		case "ERNIE-Bot-turbo": | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||||
|  | 		case "ERNIE-Bot-4": | ||||||
|  | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" | ||||||
| 		case "BLOOMZ-7B": | 		case "BLOOMZ-7B": | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||||
| 		case "Embedding-V1": | 		case "Embedding-V1": | ||||||
| @@ -179,6 +192,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		if relayMode == RelayModeEmbeddings { | 		if relayMode == RelayModeEmbeddings { | ||||||
| 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||||
| 		} | 		} | ||||||
|  | 	case APITypeTencent: | ||||||
|  | 		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" | ||||||
| 	case APITypeAIProxyLibrary: | 	case APITypeAIProxyLibrary: | ||||||
| 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | ||||||
| 	} | 	} | ||||||
| @@ -204,6 +219,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|  | 	if userQuota-preConsumedQuota < 0 { | ||||||
|  | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
|  | 	} | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
| @@ -282,6 +300,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeTencent: | ||||||
|  | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
|  | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
|  | 		appId, secretId, secretKey, err := parseTencentConfig(apiKey) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		tencentRequest := requestOpenAI2Tencent(textRequest) | ||||||
|  | 		tencentRequest.AppId = appId | ||||||
|  | 		tencentRequest.SecretId = secretId | ||||||
|  | 		jsonStr, err := json.Marshal(tencentRequest) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		sign := getTencentSign(*tencentRequest, secretKey) | ||||||
|  | 		c.Request.Header.Set("Authorization", sign) | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeAIProxyLibrary: | 	case APITypeAIProxyLibrary: | ||||||
| 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | ||||||
| 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||||
| @@ -329,11 +364,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if textRequest.Stream { | 			if textRequest.Stream { | ||||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | 				req.Header.Set("X-DashScope-SSE", "enable") | ||||||
| 			} | 			} | ||||||
|  | 		case APITypeTencent: | ||||||
|  | 			req.Header.Set("Authorization", apiKey) | ||||||
| 		default: | 		default: | ||||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||||
| 		} | 		} | ||||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
|  | 		if isStream && c.Request.Header.Get("Accept") == "" { | ||||||
|  | 			req.Header.Set("Accept", "text/event-stream") | ||||||
|  | 		} | ||||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||||
| 		resp, err = httpClient.Do(req) | 		resp, err = httpClient.Do(req) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -581,6 +621,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
|  | 	case APITypeTencent: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, responseText := tencentStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
|  | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := tencentHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
| 	default: | 	default: | ||||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -9,44 +9,53 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var stopFinishReason = "stop" | var stopFinishReason = "stop" | ||||||
|  |  | ||||||
|  | // tokenEncoderMap won't grow after initialization | ||||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||||
|  | var defaultTokenEncoder *tiktoken.Tiktoken | ||||||
|  |  | ||||||
| func InitTokenEncoders() { | func InitTokenEncoders() { | ||||||
| 	common.SysLog("initializing token encoders") | 	common.SysLog("initializing token encoders") | ||||||
| 	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) | 		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | 	defaultTokenEncoder = gpt35TokenEncoder | ||||||
|  | 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||||
| 	} | 	} | ||||||
| 	for model, _ := range common.ModelRatio { | 	for model, _ := range common.ModelRatio { | ||||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | 		if strings.HasPrefix(model, "gpt-3.5") { | ||||||
| 		if err != nil { | 			tokenEncoderMap[model] = gpt35TokenEncoder | ||||||
| 			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) | 		} else if strings.HasPrefix(model, "gpt-4") { | ||||||
| 			tokenEncoderMap[model] = fallbackTokenEncoder | 			tokenEncoderMap[model] = gpt4TokenEncoder | ||||||
| 			continue | 		} else { | ||||||
|  | 			tokenEncoderMap[model] = nil | ||||||
| 		} | 		} | ||||||
| 		tokenEncoderMap[model] = tokenEncoder |  | ||||||
| 	} | 	} | ||||||
| 	common.SysLog("token encoders initialized") | 	common.SysLog("token encoders initialized") | ||||||
| } | } | ||||||
|  |  | ||||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||||
| 	if tokenEncoder, ok := tokenEncoderMap[model]; ok { | 	tokenEncoder, ok := tokenEncoderMap[model] | ||||||
|  | 	if ok && tokenEncoder != nil { | ||||||
| 		return tokenEncoder | 		return tokenEncoder | ||||||
| 	} | 	} | ||||||
| 	tokenEncoder, err := tiktoken.EncodingForModel(model) | 	if ok { | ||||||
| 	if err != nil { | 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||||
| 		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) |  | ||||||
| 		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) | 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||||
|  | 			tokenEncoder = defaultTokenEncoder | ||||||
| 		} | 		} | ||||||
|  | 		tokenEncoderMap[model] = tokenEncoder | ||||||
|  | 		return tokenEncoder | ||||||
| 	} | 	} | ||||||
| 	tokenEncoderMap[model] = tokenEncoder | 	return defaultTokenEncoder | ||||||
| 	return tokenEncoder |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||||
| @@ -167,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr | |||||||
| 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error | 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func getFullRequestURL(baseURL string, requestURL string, channelType int) string { | ||||||
|  | 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||||
|  | 	if channelType == common.ChannelTypeOpenAI { | ||||||
|  | 		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return fullRequestURL | ||||||
|  | } | ||||||
|   | |||||||
| @@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin | |||||||
| 	for !stop { | 	for !stop { | ||||||
| 		select { | 		select { | ||||||
| 		case xunfeiResponse = <-dataChan: | 		case xunfeiResponse = <-dataChan: | ||||||
|  | 			if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
| 			content += xunfeiResponse.Payload.Choices.Text[0].Content | 			content += xunfeiResponse.Payload.Choices.Text[0].Content | ||||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||||
| @@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, | |||||||
| 		common.SysLog("api_version not found, use default: " + apiVersion) | 		common.SysLog("api_version not found, use default: " + apiVersion) | ||||||
| 	} | 	} | ||||||
| 	domain := "general" | 	domain := "general" | ||||||
| 	if apiVersion == "v2.1" { | 	if apiVersion != "v1.1" { | ||||||
| 		domain = "generalv2" | 		domain += strings.Split(apiVersion, ".")[0] | ||||||
| 	} | 	} | ||||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||||
| 	return domain, authUrl | 	return domain, authUrl | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ services: | |||||||
|     depends_on: |     depends_on: | ||||||
|       - redis |       - redis | ||||||
|     healthcheck: |     healthcheck: | ||||||
|       test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] |       test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] | ||||||
|       interval: 30s |       interval: 30s | ||||||
|       timeout: 10s |       timeout: 10s | ||||||
|       retries: 3 |       retries: 3 | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.mod
									
									
									
									
									
								
							| @@ -15,8 +15,9 @@ require ( | |||||||
| 	github.com/google/uuid v1.3.0 | 	github.com/google/uuid v1.3.0 | ||||||
| 	github.com/gorilla/websocket v1.5.0 | 	github.com/gorilla/websocket v1.5.0 | ||||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||||
| 	golang.org/x/crypto v0.9.0 | 	golang.org/x/crypto v0.14.0 | ||||||
| 	gorm.io/driver/mysql v1.4.3 | 	gorm.io/driver/mysql v1.4.3 | ||||||
|  | 	gorm.io/driver/postgres v1.5.2 | ||||||
| 	gorm.io/driver/sqlite v1.4.3 | 	gorm.io/driver/sqlite v1.4.3 | ||||||
| 	gorm.io/gorm v1.25.0 | 	gorm.io/gorm v1.25.0 | ||||||
| ) | ) | ||||||
| @@ -52,10 +53,9 @@ require ( | |||||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||||
| 	golang.org/x/arch v0.3.0 // indirect | 	golang.org/x/arch v0.3.0 // indirect | ||||||
| 	golang.org/x/net v0.10.0 // indirect | 	golang.org/x/net v0.17.0 // indirect | ||||||
| 	golang.org/x/sys v0.8.0 // indirect | 	golang.org/x/sys v0.13.0 // indirect | ||||||
| 	golang.org/x/text v0.9.0 // indirect | 	golang.org/x/text v0.13.0 // indirect | ||||||
| 	google.golang.org/protobuf v1.30.0 // indirect | 	google.golang.org/protobuf v1.30.0 // indirect | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||||
| 	gorm.io/driver/postgres v1.5.2 // indirect |  | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								go.sum
									
									
									
									
									
								
							| @@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu | |||||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||||
| golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= | golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= | ||||||
| golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= | golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= | ||||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||||
| golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= | golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | ||||||
| golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= | golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | ||||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| @@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc | |||||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= | golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= | ||||||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||||
| golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= | golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= | ||||||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||||
| @@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp | |||||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | ||||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | ||||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | ||||||
| gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= |  | ||||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | ||||||
| gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | ||||||
| gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||||
|   | |||||||
							
								
								
									
										20
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								main.go
									
									
									
									
									
								
							| @@ -2,6 +2,7 @@ package main | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"embed" | 	"embed" | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-contrib/sessions/cookie" | 	"github.com/gin-contrib/sessions/cookie" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -50,18 +51,17 @@ func main() { | |||||||
| 	// Initialize options | 	// Initialize options | ||||||
| 	model.InitOptionMap() | 	model.InitOptionMap() | ||||||
| 	if common.RedisEnabled { | 	if common.RedisEnabled { | ||||||
|  | 		// for compatibility with old versions | ||||||
|  | 		common.MemoryCacheEnabled = true | ||||||
|  | 	} | ||||||
|  | 	if common.MemoryCacheEnabled { | ||||||
|  | 		common.SysLog("memory cache enabled") | ||||||
|  | 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) | ||||||
| 		model.InitChannelCache() | 		model.InitChannelCache() | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("SYNC_FREQUENCY") != "" { | 	if common.MemoryCacheEnabled { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY")) | 		go model.SyncOptions(common.SyncFrequency) | ||||||
| 		if err != nil { | 		go model.SyncChannelCache(common.SyncFrequency) | ||||||
| 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		common.SyncFrequency = frequency |  | ||||||
| 		go model.SyncOptions(frequency) |  | ||||||
| 		if common.RedisEnabled { |  | ||||||
| 			go model.SyncChannelCache(frequency) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||||
|   | |||||||
| @@ -94,7 +94,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		userEnabled, err := model.IsUserEnabled(token.UserId) | 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
| 			return | 			return | ||||||
|   | |||||||
| @@ -25,12 +25,12 @@ func Distribute() func(c *gin.Context) { | |||||||
| 		if ok { | 		if ok { | ||||||
| 			id, err := strconv.Atoi(channelId.(string)) | 			id, err := strconv.Atoi(channelId.(string)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			channel, err = model.GetChannelById(id, true) | 			channel, err = model.GetChannelById(id, true) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if channel.Status != common.ChannelStatusEnabled { | 			if channel.Status != common.ChannelStatusEnabled { | ||||||
|   | |||||||
| @@ -15,10 +15,17 @@ type Ability struct { | |||||||
|  |  | ||||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||||
| 	ability := Ability{} | 	ability := Ability{} | ||||||
|  | 	groupCol := "`group`" | ||||||
|  | 	trueVal := "1" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		groupCol = `"group"` | ||||||
|  | 		trueVal = "true" | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	var err error = nil | 	var err error = nil | ||||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) | 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) | ||||||
| 	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) | 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) | ||||||
| 	if common.UsingSQLite { | 	if common.UsingSQLite || common.UsingPostgreSQL { | ||||||
| 		err = channelQuery.Order("RANDOM()").First(&ability).Error | 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||||
| 	} else { | 	} else { | ||||||
| 		err = channelQuery.Order("RAND()").First(&ability).Error | 		err = channelQuery.Order("RAND()").First(&ability).Error | ||||||
|   | |||||||
| @@ -21,14 +21,18 @@ var ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func CacheGetTokenByKey(key string) (*Token, error) { | func CacheGetTokenByKey(key string) (*Token, error) { | ||||||
|  | 	keyCol := "`key`" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		keyCol = `"key"` | ||||||
|  | 	} | ||||||
| 	var token Token | 	var token Token | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||||
| 		return &token, err | 		return &token, err | ||||||
| 	} | 	} | ||||||
| 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) | 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		err := DB.Where("`key` = ?", key).First(&token).Error | 		err := DB.Where(keyCol+" = ?", key).First(&token).Error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| @@ -186,7 +190,7 @@ func SyncChannelCache(frequency int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.MemoryCacheEnabled { | ||||||
| 		return GetRandomSatisfiedChannel(group, model) | 		return GetRandomSatisfiedChannel(group, model) | ||||||
| 	} | 	} | ||||||
| 	channelSyncLock.RLock() | 	channelSyncLock.RLock() | ||||||
|   | |||||||
| @@ -11,7 +11,7 @@ type Channel struct { | |||||||
| 	Key                string  `json:"key" gorm:"not null;index"` | 	Key                string  `json:"key" gorm:"not null;index"` | ||||||
| 	Status             int     `json:"status" gorm:"default:1"` | 	Status             int     `json:"status" gorm:"default:1"` | ||||||
| 	Name               string  `json:"name" gorm:"index"` | 	Name               string  `json:"name" gorm:"index"` | ||||||
| 	Weight             int     `json:"weight"` | 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||||
| 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | ||||||
| 	TestTime           int64   `json:"test_time" gorm:"bigint"` | 	TestTime           int64   `json:"test_time" gorm:"bigint"` | ||||||
| 	ResponseTime       int     `json:"response_time"` // in milliseconds | 	ResponseTime       int     `json:"response_time"` // in milliseconds | ||||||
| @@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error | 	keyCol := "`key`" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		keyCol = `"key"` | ||||||
|  | 	} | ||||||
|  | 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||||
| 	return channels, err | 	return channels, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { | |||||||
| 	return &channel, err | 	return &channel, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetRandomChannel() (*Channel, error) { |  | ||||||
| 	channel := Channel{} |  | ||||||
| 	var err error = nil |  | ||||||
| 	if common.UsingSQLite { |  | ||||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error |  | ||||||
| 	} else { |  | ||||||
| 		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error |  | ||||||
| 	} |  | ||||||
| 	return &channel, err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func BatchInsertChannels(channels []Channel) error { | func BatchInsertChannels(channels []Channel) error { | ||||||
| 	var err error | 	var err error | ||||||
| 	err = DB.Create(&channels).Error | 	err = DB.Create(&channels).Error | ||||||
| @@ -176,3 +169,13 @@ func updateChannelUsedQuota(id int, quota int) { | |||||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func DeleteChannelByStatus(status int64) (int64, error) { | ||||||
|  | 	result := DB.Where("status = ?", status).Delete(&Channel{}) | ||||||
|  | 	return result.RowsAffected, result.Error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DeleteDisabledChannel() (int64, error) { | ||||||
|  | 	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||||
|  | 	return result.RowsAffected, result.Error | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -8,18 +8,18 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type Log struct { | type Log struct { | ||||||
| 	Id               int    `json:"id"` | 	Id               int    `json:"id;index:idx_created_at_id,priority:1"` | ||||||
| 	UserId           int    `json:"user_id"` | 	UserId           int    `json:"user_id" gorm:"index"` | ||||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index"` | 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` | ||||||
| 	Type             int    `json:"type" gorm:"index"` | 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||||
| 	Content          string `json:"content"` | 	Content          string `json:"content"` | ||||||
| 	Username         string `json:"username" gorm:"index;default:''"` | 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||||
| 	TokenName        string `json:"token_name" gorm:"index;default:''"` | 	TokenName        string `json:"token_name" gorm:"index;default:''"` | ||||||
| 	ModelName        string `json:"model_name" gorm:"index;default:''"` | 	ModelName        string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | ||||||
| 	Quota            int    `json:"quota" gorm:"default:0"` | 	Quota            int    `json:"quota" gorm:"default:0"` | ||||||
| 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | ||||||
| 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | ||||||
| 	Channel          int    `json:"channel" gorm:"default:0"` | 	ChannelId        int    `json:"channel" gorm:"index"` | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -47,7 +47,6 @@ func RecordLog(userId int, logType int, content string) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||||
| 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||||
| 	if !common.LogConsumeEnabled { | 	if !common.LogConsumeEnabled { | ||||||
| @@ -64,7 +63,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | |||||||
| 		TokenName:        tokenName, | 		TokenName:        tokenName, | ||||||
| 		ModelName:        modelName, | 		ModelName:        modelName, | ||||||
| 		Quota:            quota, | 		Quota:            quota, | ||||||
| 		Channel:          channelId, | 		ChannelId:        channelId, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) { | |||||||
| 		if strings.HasPrefix(dsn, "postgres://") { | 		if strings.HasPrefix(dsn, "postgres://") { | ||||||
| 			// Use PostgreSQL | 			// Use PostgreSQL | ||||||
| 			common.SysLog("using PostgreSQL as database") | 			common.SysLog("using PostgreSQL as database") | ||||||
|  | 			common.UsingPostgreSQL = true | ||||||
| 			return gorm.Open(postgres.New(postgres.Config{ | 			return gorm.Open(postgres.New(postgres.Config{ | ||||||
| 				DSN:                  dsn, | 				DSN:                  dsn, | ||||||
| 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | 				PreferSimpleProtocol: true, // disables implicit prepared statement usage | ||||||
| @@ -81,6 +82,7 @@ func InitDB() (err error) { | |||||||
| 		if !common.IsMasterNode { | 		if !common.IsMasterNode { | ||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
|  | 		common.SysLog("database migration started") | ||||||
| 		err = db.AutoMigrate(&Channel{}) | 		err = db.AutoMigrate(&Channel{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
|   | |||||||
| @@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) { | |||||||
| 	} | 	} | ||||||
| 	redemption := &Redemption{} | 	redemption := &Redemption{} | ||||||
|  |  | ||||||
|  | 	keyCol := "`key`" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		keyCol = `"key"` | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errors.New("无效的兑换码") | 			return errors.New("无效的兑换码") | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserGroup(id int) (group string, err error) { | func GetUserGroup(id int) (group string, err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error | 	groupCol := "`group`" | ||||||
|  | 	if common.UsingPostgreSQL { | ||||||
|  | 		groupCol = `"group"` | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error | ||||||
| 	return group, err | 	return group, err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) { | |||||||
|  |  | ||||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||||
| 	if common.BatchUpdateEnabled { | 	if common.BatchUpdateEnabled { | ||||||
| 		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) | 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||||
|  | 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||||
| @@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func updateUserUsedQuota(id int, quota int) { | ||||||
|  | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
|  | 		map[string]interface{}{ | ||||||
|  | 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||||
|  | 		}, | ||||||
|  | 	).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("failed to update user used quota: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateUserRequestCount(id int, count int) { | ||||||
|  | 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("failed to update user request count: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func GetUsernameById(id int) (username string) { | func GetUsernameById(id int) (username string) { | ||||||
| 	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) | 	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) | ||||||
| 	return username | 	return username | ||||||
|   | |||||||
| @@ -6,13 +6,13 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock |  | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	BatchUpdateTypeUserQuota = iota | 	BatchUpdateTypeUserQuota = iota | ||||||
| 	BatchUpdateTypeTokenQuota | 	BatchUpdateTypeTokenQuota | ||||||
| 	BatchUpdateTypeUsedQuotaAndRequestCount | 	BatchUpdateTypeUsedQuota | ||||||
| 	BatchUpdateTypeChannelUsedQuota | 	BatchUpdateTypeChannelUsedQuota | ||||||
|  | 	BatchUpdateTypeRequestCount | ||||||
|  | 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var batchUpdateStores []map[int]int | var batchUpdateStores []map[int]int | ||||||
| @@ -51,7 +51,7 @@ func batchUpdate() { | |||||||
| 		store := batchUpdateStores[i] | 		store := batchUpdateStores[i] | ||||||
| 		batchUpdateStores[i] = make(map[int]int) | 		batchUpdateStores[i] = make(map[int]int) | ||||||
| 		batchUpdateLocks[i].Unlock() | 		batchUpdateLocks[i].Unlock() | ||||||
|  | 		// TODO: maybe we can combine updates with same key? | ||||||
| 		for key, value := range store { | 		for key, value := range store { | ||||||
| 			switch i { | 			switch i { | ||||||
| 			case BatchUpdateTypeUserQuota: | 			case BatchUpdateTypeUserQuota: | ||||||
| @@ -64,8 +64,10 @@ func batchUpdate() { | |||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("failed to batch update token quota: " + err.Error()) | 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||||
| 				} | 				} | ||||||
| 			case BatchUpdateTypeUsedQuotaAndRequestCount: | 			case BatchUpdateTypeUsedQuota: | ||||||
| 				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect | 				updateUserUsedQuota(key, value) | ||||||
|  | 			case BatchUpdateTypeRequestCount: | ||||||
|  | 				updateUserRequestCount(key, value) | ||||||
| 			case BatchUpdateTypeChannelUsedQuota: | 			case BatchUpdateTypeChannelUsedQuota: | ||||||
| 				updateChannelUsedQuota(key, value) | 				updateChannelUsedQuota(key, value) | ||||||
| 			} | 			} | ||||||
|   | |||||||
| @@ -74,6 +74,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | ||||||
| 			channelRoute.POST("/", controller.AddChannel) | 			channelRoute.POST("/", controller.AddChannel) | ||||||
| 			channelRoute.PUT("/", controller.UpdateChannel) | 			channelRoute.PUT("/", controller.UpdateChannel) | ||||||
|  | 			channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel) | ||||||
| 			channelRoute.DELETE("/:id", controller.DeleteChannel) | 			channelRoute.DELETE("/:id", controller.DeleteChannel) | ||||||
| 		} | 		} | ||||||
| 		tokenRoute := apiRouter.Group("/token") | 		tokenRoute := apiRouter.Group("/token") | ||||||
|   | |||||||
| @@ -283,7 +283,9 @@ function App() { | |||||||
|           </Suspense> |           </Suspense> | ||||||
|         } |         } | ||||||
|       /> |       /> | ||||||
|       <Route path='*' element={NotFound} /> |       <Route path='*' element={ | ||||||
|  |           <NotFound /> | ||||||
|  |       } /> | ||||||
|     </Routes> |     </Routes> | ||||||
|   ); |   ); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; | import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link } from 'react-router-dom'; | ||||||
| import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; | import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | ||||||
|  |  | ||||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||||
| import { renderGroup, renderNumber } from '../helpers/render'; | import { renderGroup, renderNumber } from '../helpers/render'; | ||||||
| @@ -55,6 +55,7 @@ const ChannelsTable = () => { | |||||||
|   const [searchKeyword, setSearchKeyword] = useState(''); |   const [searchKeyword, setSearchKeyword] = useState(''); | ||||||
|   const [searching, setSearching] = useState(false); |   const [searching, setSearching] = useState(false); | ||||||
|   const [updatingBalance, setUpdatingBalance] = useState(false); |   const [updatingBalance, setUpdatingBalance] = useState(false); | ||||||
|  |   const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test")); | ||||||
|  |  | ||||||
|   const loadChannels = async (startIdx) => { |   const loadChannels = async (startIdx) => { | ||||||
|     const res = await API.get(`/api/channel/?p=${startIdx}`); |     const res = await API.get(`/api/channel/?p=${startIdx}`); | ||||||
| @@ -96,7 +97,7 @@ const ChannelsTable = () => { | |||||||
|       }); |       }); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|   const manageChannel = async (id, action, idx, priority) => { |   const manageChannel = async (id, action, idx, value) => { | ||||||
|     let data = { id }; |     let data = { id }; | ||||||
|     let res; |     let res; | ||||||
|     switch (action) { |     switch (action) { | ||||||
| @@ -112,10 +113,20 @@ const ChannelsTable = () => { | |||||||
|         res = await API.put('/api/channel/', data); |         res = await API.put('/api/channel/', data); | ||||||
|         break; |         break; | ||||||
|       case 'priority': |       case 'priority': | ||||||
|         if (priority === '') { |         if (value === '') { | ||||||
|           return; |           return; | ||||||
|         } |         } | ||||||
|         data.priority = parseInt(priority); |         data.priority = parseInt(value); | ||||||
|  |         res = await API.put('/api/channel/', data); | ||||||
|  |         break; | ||||||
|  |       case 'weight': | ||||||
|  |         if (value === '') { | ||||||
|  |           return; | ||||||
|  |         } | ||||||
|  |         data.weight = parseInt(value); | ||||||
|  |         if (data.weight < 0) { | ||||||
|  |           data.weight = 0; | ||||||
|  |         } | ||||||
|         res = await API.put('/api/channel/', data); |         res = await API.put('/api/channel/', data); | ||||||
|         break; |         break; | ||||||
|     } |     } | ||||||
| @@ -142,9 +153,23 @@ const ChannelsTable = () => { | |||||||
|         return <Label basic color='green'>已启用</Label>; |         return <Label basic color='green'>已启用</Label>; | ||||||
|       case 2: |       case 2: | ||||||
|         return ( |         return ( | ||||||
|           <Label basic color='red'> |           <Popup | ||||||
|             已禁用 |             trigger={<Label basic color='red'> | ||||||
|           </Label> |               已禁用 | ||||||
|  |             </Label>} | ||||||
|  |             content='本渠道被手动禁用' | ||||||
|  |             basic | ||||||
|  |           /> | ||||||
|  |         ); | ||||||
|  |       case 3: | ||||||
|  |         return ( | ||||||
|  |           <Popup | ||||||
|  |             trigger={<Label basic color='yellow'> | ||||||
|  |               已禁用 | ||||||
|  |             </Label>} | ||||||
|  |             content='本渠道被程序自动禁用' | ||||||
|  |             basic | ||||||
|  |           /> | ||||||
|         ); |         ); | ||||||
|       default: |       default: | ||||||
|         return ( |         return ( | ||||||
| @@ -202,7 +227,6 @@ const ChannelsTable = () => { | |||||||
|       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); |       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|       showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") |  | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
| @@ -216,6 +240,17 @@ const ChannelsTable = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const deleteAllDisabledChannels = async () => { | ||||||
|  |     const res = await API.delete(`/api/channel/disabled`); | ||||||
|  |     const { success, message, data } = res.data; | ||||||
|  |     if (success) { | ||||||
|  |       showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); | ||||||
|  |       await refresh(); | ||||||
|  |     } else { | ||||||
|  |       showError(message); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   const updateChannelBalance = async (id, name, idx) => { |   const updateChannelBalance = async (id, name, idx) => { | ||||||
|     const res = await API.get(`/api/channel/update_balance/${id}/`); |     const res = await API.get(`/api/channel/update_balance/${id}/`); | ||||||
|     const { success, message, balance } = res.data; |     const { success, message, balance } = res.data; | ||||||
| @@ -282,7 +317,19 @@ const ChannelsTable = () => { | |||||||
|           onChange={handleKeywordChange} |           onChange={handleKeywordChange} | ||||||
|         /> |         /> | ||||||
|       </Form> |       </Form> | ||||||
|  |       { | ||||||
|  |         showPrompt && ( | ||||||
|  |           <Message onDismiss={() => { | ||||||
|  |             setShowPrompt(false); | ||||||
|  |             setPromptShown("channel-test"); | ||||||
|  |           }}> | ||||||
|  |             当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo | ||||||
|  |             模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 | ||||||
|  |  | ||||||
|  |             另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 | ||||||
|  |           </Message> | ||||||
|  |         ) | ||||||
|  |       } | ||||||
|       <Table basic compact size='small'> |       <Table basic compact size='small'> | ||||||
|         <Table.Header> |         <Table.Header> | ||||||
|           <Table.Row> |           <Table.Row> | ||||||
| @@ -343,10 +390,10 @@ const ChannelsTable = () => { | |||||||
|               余额 |               余额 | ||||||
|             </Table.HeaderCell> |             </Table.HeaderCell> | ||||||
|             <Table.HeaderCell |             <Table.HeaderCell | ||||||
|                 style={{ cursor: 'pointer' }} |               style={{ cursor: 'pointer' }} | ||||||
|                 onClick={() => { |               onClick={() => { | ||||||
|                   sortChannel('priority'); |                 sortChannel('priority'); | ||||||
|                 }} |               }} | ||||||
|             > |             > | ||||||
|               优先级 |               优先级 | ||||||
|             </Table.HeaderCell> |             </Table.HeaderCell> | ||||||
| @@ -390,18 +437,18 @@ const ChannelsTable = () => { | |||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
|                   <Table.Cell> |                   <Table.Cell> | ||||||
|                     <Popup |                     <Popup | ||||||
|                         trigger={<Input type="number"  defaultValue={channel.priority} onBlur={(event) => { |                       trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => { | ||||||
|                           manageChannel( |                         manageChannel( | ||||||
|                               channel.id, |                           channel.id, | ||||||
|                               'priority', |                           'priority', | ||||||
|                               idx, |                           idx, | ||||||
|                               event.target.value, |                           event.target.value | ||||||
|                           ); |                         ); | ||||||
|                         }}> |                       }}> | ||||||
|                           <input style={{maxWidth:'60px'}} /> |                         <input style={{ maxWidth: '60px' }} /> | ||||||
|                         </Input>} |                       </Input>} | ||||||
|                         content='渠道选择优先级,越高越优先' |                       content='渠道选择优先级,越高越优先' | ||||||
|                         basic |                       basic | ||||||
|                     /> |                     /> | ||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
|                   <Table.Cell> |                   <Table.Cell> | ||||||
| @@ -481,6 +528,20 @@ const ChannelsTable = () => { | |||||||
|               </Button> |               </Button> | ||||||
|               <Button size='small' onClick={updateAllChannelsBalance} |               <Button size='small' onClick={updateAllChannelsBalance} | ||||||
|                       loading={loading || updatingBalance}>更新所有已启用通道余额</Button> |                       loading={loading || updatingBalance}>更新所有已启用通道余额</Button> | ||||||
|  |               <Popup | ||||||
|  |                 trigger={ | ||||||
|  |                   <Button size='small' loading={loading}> | ||||||
|  |                     删除禁用渠道 | ||||||
|  |                   </Button> | ||||||
|  |                 } | ||||||
|  |                 on='click' | ||||||
|  |                 flowing | ||||||
|  |                 hoverable | ||||||
|  |               > | ||||||
|  |                 <Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}> | ||||||
|  |                   确认删除 | ||||||
|  |                 </Button> | ||||||
|  |               </Popup> | ||||||
|               <Pagination |               <Pagination | ||||||
|                 floated='right' |                 floated='right' | ||||||
|                 activePage={activePage} |                 activePage={activePage} | ||||||
|   | |||||||
| @@ -2,8 +2,8 @@ import React, { useContext, useEffect, useState } from 'react'; | |||||||
| import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||||
| import { UserContext } from '../context/User'; | import { UserContext } from '../context/User'; | ||||||
| import { API, getLogo, showError, showSuccess } from '../helpers'; | import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; | ||||||
| import { getOAuthState, onGitHubOAuthClicked } from './utils'; | import { onGitHubOAuthClicked } from './utils'; | ||||||
|  |  | ||||||
| const LoginForm = () => { | const LoginForm = () => { | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
| @@ -68,8 +68,14 @@ const LoginForm = () => { | |||||||
|       if (success) { |       if (success) { | ||||||
|         userDispatch({ type: 'login', payload: data }); |         userDispatch({ type: 'login', payload: data }); | ||||||
|         localStorage.setItem('user', JSON.stringify(data)); |         localStorage.setItem('user', JSON.stringify(data)); | ||||||
|         navigate('/'); |         if (username === 'root' && password === '123456') { | ||||||
|         showSuccess('登录成功!'); |           navigate('/user/edit'); | ||||||
|  |           showSuccess('登录成功!'); | ||||||
|  |           showWarning('请立刻修改默认密码!'); | ||||||
|  |         } else { | ||||||
|  |           navigate('/token'); | ||||||
|  |           showSuccess('登录成功!'); | ||||||
|  |         } | ||||||
|       } else { |       } else { | ||||||
|         showError(message); |         showError(message); | ||||||
|       } |       } | ||||||
| @@ -126,7 +132,7 @@ const LoginForm = () => { | |||||||
|                 circular |                 circular | ||||||
|                 color='black' |                 color='black' | ||||||
|                 icon='github' |                 icon='github' | ||||||
|                 onClick={()=>onGitHubOAuthClicked(status.github_client_id)} |                 onClick={() => onGitHubOAuthClicked(status.github_client_id)} | ||||||
|               /> |               /> | ||||||
|             ) : ( |             ) : ( | ||||||
|               <></> |               <></> | ||||||
|   | |||||||
| @@ -138,7 +138,7 @@ const TokensTable = () => { | |||||||
|     let defaultUrl; |     let defaultUrl; | ||||||
|    |    | ||||||
|     if (chatLink) { |     if (chatLink) { | ||||||
|       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; |       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|     } else { |     } else { | ||||||
|       defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; |       defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ export const CHANNEL_OPTIONS = [ | |||||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, |   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, |   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||||
|   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, |   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, | ||||||
|  |   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, | ||||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, |   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, |   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, |   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||||
|   | |||||||
| @@ -186,4 +186,14 @@ export const verifyJSON = (str) => { | |||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|   return true; |   return true; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | export function shouldShowPrompt(id) { | ||||||
|  |   let prompt = localStorage.getItem(`prompt-${id}`); | ||||||
|  |   return !prompt; | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | export function setPromptShown(id) { | ||||||
|  |   localStorage.setItem(`prompt-${id}`, 'true'); | ||||||
|  | } | ||||||
| @@ -19,6 +19,8 @@ function type2secretPrompt(type) { | |||||||
|       return '按照如下格式输入:APPID|APISecret|APIKey'; |       return '按照如下格式输入:APPID|APISecret|APIKey'; | ||||||
|     case 22: |     case 22: | ||||||
|       return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; |       return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; | ||||||
|  |     case 23: | ||||||
|  |       return '按照如下格式输入:AppId|SecretId|SecretKey'; | ||||||
|     default: |     default: | ||||||
|       return '请输入渠道对应的鉴权密钥'; |       return '请输入渠道对应的鉴权密钥'; | ||||||
|   } |   } | ||||||
| @@ -64,7 +66,7 @@ const EditChannel = () => { | |||||||
|           localModels = ['PaLM-2']; |           localModels = ['PaLM-2']; | ||||||
|           break; |           break; | ||||||
|         case 15: |         case 15: | ||||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; |           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; | ||||||
|           break; |           break; | ||||||
|         case 17: |         case 17: | ||||||
|           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; |           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; | ||||||
| @@ -76,7 +78,10 @@ const EditChannel = () => { | |||||||
|           localModels = ['SparkDesk']; |           localModels = ['SparkDesk']; | ||||||
|           break; |           break; | ||||||
|         case 19: |         case 19: | ||||||
|           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4']; |           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; | ||||||
|  |           break; | ||||||
|  |         case 23: | ||||||
|  |           localModels = ['hunyuan']; | ||||||
|           break; |           break; | ||||||
|       } |       } | ||||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); |       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||||
|   | |||||||
| @@ -1,19 +1,12 @@ | |||||||
| import React from 'react'; | import React from 'react'; | ||||||
| import { Segment, Header } from 'semantic-ui-react'; | import { Message } from 'semantic-ui-react'; | ||||||
|  |  | ||||||
| const NotFound = () => ( | const NotFound = () => ( | ||||||
|   <> |   <> | ||||||
|     <Header |     <Message negative> | ||||||
|       block |       <Message.Header>页面不存在</Message.Header> | ||||||
|       as="h4" |       <p>请检查你的浏览器地址是否正确</p> | ||||||
|       content="404" |     </Message> | ||||||
|       attached="top" |  | ||||||
|       icon="info" |  | ||||||
|       className="small-icon" |  | ||||||
|     /> |  | ||||||
|     <Segment attached="bottom"> |  | ||||||
|       未找到所请求的页面 |  | ||||||
|     </Segment> |  | ||||||
|   </> |   </> | ||||||
| ); | ); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -102,7 +102,7 @@ const EditUser = () => { | |||||||
|               label='密码' |               label='密码' | ||||||
|               name='password' |               name='password' | ||||||
|               type={'password'} |               type={'password'} | ||||||
|               placeholder={'请输入新的密码'} |               placeholder={'请输入新的密码,最短 8 位'} | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={password} |               value={password} | ||||||
|               autoComplete='new-password' |               autoComplete='new-password' | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user