mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			31 Commits
		
	
	
		
			v0.5.0-alp
			...
			v0.5.2-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | fe8f216dd9 | ||
|  | b7d0616ae0 | ||
|  | ce9c8024a6 | ||
|  | 8a866078b2 | ||
|  | 3e81d8af45 | ||
|  | b8cb86c2c1 | ||
|  | f45d586400 | ||
|  | 50dec03ff3 | ||
|  | f31d400b6f | ||
|  | 130e6bfd83 | ||
|  | d1335ebc01 | ||
|  | e92da7928b | ||
|  | d1b6f492b6 | ||
|  | b9f6461dd4 | ||
|  | 0a39521a3d | ||
|  | c134604cee | ||
|  | 929e43ef81 | ||
|  | dce8bbe1ca | ||
|  | bc2f48b1f2 | ||
|  | 889af8b2db | ||
|  | 4eea096654 | ||
|  | 4ab3211c0e | ||
|  | 3da119efba | ||
|  | dccd66b852 | ||
|  | 2fcd6852e0 | ||
|  | 9b4d1964d4 | ||
|  | 806bf8241c | ||
|  | ce93c9b6b2 | ||
|  | 4ec4289565 | ||
|  | 3dc5a0f91d | ||
|  | 80a846673a | 
							
								
								
									
										18
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -10,7 +10,7 @@ | |||||||
|  |  | ||||||
| # One API | # One API | ||||||
|  |  | ||||||
| _✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_ | _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
|  |  | ||||||
| @@ -57,15 +57,13 @@ _✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_ | |||||||
| > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | ||||||
|  |  | ||||||
| ## Features | ## Features | ||||||
| 1. Supports multiple API access channels: | 1. Support for multiple large models: | ||||||
|     + [x] Official OpenAI channel (support proxy configuration) |    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|     + [x] **Azure OpenAI API** |    + [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||||
|     + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) |    + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) | ||||||
|     + [x] [OpenAI-SB](https://openai-sb.com) |    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|     + [x] [API2D](https://api2d.com/r/197971) |    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|     + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) |    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||||
|     + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`) |  | ||||||
|     + [x] Custom channel: Various third-party proxy services not included in the list |  | ||||||
| 2. Supports access to multiple channels through **load balancing**. | 2. Supports access to multiple channels through **load balancing**. | ||||||
| 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | ||||||
| 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | ||||||
|   | |||||||
							
								
								
									
										49
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								README.md
									
									
									
									
									
								
							| @@ -11,7 +11,7 @@ | |||||||
|  |  | ||||||
| # One API | # One API | ||||||
|  |  | ||||||
| _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用✨_ | _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ | ||||||
|  |  | ||||||
| </div> | </div> | ||||||
|  |  | ||||||
| @@ -58,41 +58,42 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 | |||||||
| > **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。 | > **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。 | ||||||
|  |  | ||||||
| ## 功能 | ## 功能 | ||||||
| 1. 支持多种 API 访问渠道: | 1. 支持多种大模型: | ||||||
|    + [x] OpenAI 官方通道(支持配置镜像) |    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
|    + [x] **Azure OpenAI API** |  | ||||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) |    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||||
|    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) |    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) | ||||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) |    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||||
|  |    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||||
|  |    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) |    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||||
|    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) | 2. 支持配置镜像以及众多第三方代理服务: | ||||||
|    + [x] [OpenAI-SB](https://openai-sb.com) |    + [x] [OpenAI-SB](https://openai-sb.com) | ||||||
|    + [x] [API2D](https://api2d.com/r/197971) |    + [x] [API2D](https://api2d.com/r/197971) | ||||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) |    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) |    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) |    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||||
|    + [x] 自定义渠道:例如各种未收录的第三方代理服务 |    + [x] 自定义渠道:例如各种未收录的第三方代理服务 | ||||||
| 2. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| 4. 支持**多机部署**,[详见此处](#多机部署)。 | 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||||
| 5. 支持**令牌管理**,设置令牌的过期时间和额度。 | 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||||
| 6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||||
| 7. 支持**通道管理**,批量创建通道。 | 8. 支持**通道管理**,批量创建通道。 | ||||||
| 8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||||
| 9. 支持渠道**设置模型列表**。 | 10. 支持渠道**设置模型列表**。 | ||||||
| 10. 支持**查看额度明细**。 | 11. 支持**查看额度明细**。 | ||||||
| 11. 支持**用户邀请奖励**。 | 12. 支持**用户邀请奖励**。 | ||||||
| 12. 支持以美元为单位显示额度。 | 13. 支持以美元为单位显示额度。 | ||||||
| 13. 支持发布公告,设置充值链接,设置新用户初始额度。 | 14. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||||
| 14. 支持模型映射,重定向用户的请求模型。 | 15. 支持模型映射,重定向用户的请求模型。 | ||||||
| 15. 支持失败自动重试。 | 16. 支持失败自动重试。 | ||||||
| 16. 支持绘图接口。 | 17. 支持绘图接口。 | ||||||
| 17. 支持丰富的**自定义**设置, | 18. 支持丰富的**自定义**设置, | ||||||
|     1. 支持自定义系统名称,logo 以及页脚。 |     1. 支持自定义系统名称,logo 以及页脚。 | ||||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 |     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||||
| 18. 支持通过系统访问令牌访问管理 API。 | 19. 支持通过系统访问令牌访问管理 API。 | ||||||
| 19. 支持 Cloudflare Turnstile 用户校验。 | 20. 支持 Cloudflare Turnstile 用户校验。 | ||||||
| 20. 支持用户管理,支持**多种用户登录注册方式**: | 21. 支持用户管理,支持**多种用户登录注册方式**: | ||||||
|     + 邮箱登录注册以及通过邮箱进行密码重置。 |     + 邮箱登录注册以及通过邮箱进行密码重置。 | ||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
|   | |||||||
| @@ -77,6 +77,8 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | |||||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | var RequestInterval = time.Duration(requestInterval) * time.Second | ||||||
|  |  | ||||||
|  | var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	RoleGuestUser  = 0 | 	RoleGuestUser  = 0 | ||||||
| 	RoleCommonUser = 1 | 	RoleCommonUser = 1 | ||||||
| @@ -154,6 +156,8 @@ const ( | |||||||
| 	ChannelTypeAnthropic = 14 | 	ChannelTypeAnthropic = 14 | ||||||
| 	ChannelTypeBaidu     = 15 | 	ChannelTypeBaidu     = 15 | ||||||
| 	ChannelTypeZhipu     = 16 | 	ChannelTypeZhipu     = 16 | ||||||
|  | 	ChannelTypeAli       = 17 | ||||||
|  | 	ChannelTypeXunfei    = 18 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| @@ -174,4 +178,6 @@ var ChannelBaseURLs = []string{ | |||||||
| 	"https://api.anthropic.com",      // 14 | 	"https://api.anthropic.com",      // 14 | ||||||
| 	"https://aip.baidubce.com",       // 15 | 	"https://aip.baidubce.com",       // 15 | ||||||
| 	"https://open.bigmodel.cn",       // 16 | 	"https://open.bigmodel.cn",       // 16 | ||||||
|  | 	"https://dashscope.aliyuncs.com", // 17 | ||||||
|  | 	"",                               // 18 | ||||||
| } | } | ||||||
|   | |||||||
| @@ -42,10 +42,14 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"claude-2":                30, | 	"claude-2":                30, | ||||||
| 	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens | 	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens | ||||||
| 	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens | 	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens | ||||||
|  | 	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens | ||||||
| 	"PaLM-2":                  1, | 	"PaLM-2":                  1, | ||||||
| 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens | 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens | ||||||
| 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens | 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens | ||||||
| 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens | 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens | ||||||
|  | 	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag | ||||||
|  | 	"qwen-plus-v1":            0.5715, // Same as above | ||||||
|  | 	"SparkDesk":               0.8572, // TBD | ||||||
| } | } | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { | func ModelRatio2JSONString() string { | ||||||
|   | |||||||
| @@ -11,9 +11,11 @@ func GetSubscription(c *gin.Context) { | |||||||
| 	var usedQuota int | 	var usedQuota int | ||||||
| 	var err error | 	var err error | ||||||
| 	var token *model.Token | 	var token *model.Token | ||||||
|  | 	var expiredTime int64 | ||||||
| 	if common.DisplayTokenStatEnabled { | 	if common.DisplayTokenStatEnabled { | ||||||
| 		tokenId := c.GetInt("token_id") | 		tokenId := c.GetInt("token_id") | ||||||
| 		token, err = model.GetTokenById(tokenId) | 		token, err = model.GetTokenById(tokenId) | ||||||
|  | 		expiredTime = token.ExpiredTime | ||||||
| 		remainQuota = token.RemainQuota | 		remainQuota = token.RemainQuota | ||||||
| 		usedQuota = token.UsedQuota | 		usedQuota = token.UsedQuota | ||||||
| 	} else { | 	} else { | ||||||
| @@ -21,6 +23,9 @@ func GetSubscription(c *gin.Context) { | |||||||
| 		remainQuota, err = model.GetUserQuota(userId) | 		remainQuota, err = model.GetUserQuota(userId) | ||||||
| 		usedQuota, err = model.GetUserUsedQuota(userId) | 		usedQuota, err = model.GetUserUsedQuota(userId) | ||||||
| 	} | 	} | ||||||
|  | 	if expiredTime <= 0 { | ||||||
|  | 		expiredTime = 0 | ||||||
|  | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		openAIError := OpenAIError{ | 		openAIError := OpenAIError{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| @@ -45,6 +50,7 @@ func GetSubscription(c *gin.Context) { | |||||||
| 		SoftLimitUSD:       amount, | 		SoftLimitUSD:       amount, | ||||||
| 		HardLimitUSD:       amount, | 		HardLimitUSD:       amount, | ||||||
| 		SystemHardLimitUSD: amount, | 		SystemHardLimitUSD: amount, | ||||||
|  | 		AccessUntil:        expiredTime, | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, subscription) | 	c.JSON(200, subscription) | ||||||
| 	return | 	return | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ type OpenAISubscriptionResponse struct { | |||||||
| 	SoftLimitUSD       float64 `json:"soft_limit_usd"` | 	SoftLimitUSD       float64 `json:"soft_limit_usd"` | ||||||
| 	HardLimitUSD       float64 `json:"hard_limit_usd"` | 	HardLimitUSD       float64 `json:"hard_limit_usd"` | ||||||
| 	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` | 	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` | ||||||
|  | 	AccessUntil        int64   `json:"access_until"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type OpenAIUsageDailyCost struct { | type OpenAIUsageDailyCost struct { | ||||||
| @@ -84,7 +85,6 @@ func GetAuthHeader(token string) http.Header { | |||||||
| } | } | ||||||
|  |  | ||||||
| func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { | func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { | ||||||
| 	client := &http.Client{} |  | ||||||
| 	req, err := http.NewRequest(method, url, nil) | 	req, err := http.NewRequest(method, url, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -92,10 +92,13 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | |||||||
| 	for k := range headers { | 	for k := range headers { | ||||||
| 		req.Header.Add(k, headers.Get(k)) | 		req.Header.Add(k, headers.Get(k)) | ||||||
| 	} | 	} | ||||||
| 	res, err := client.Do(req) | 	res, err := httpClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	if res.StatusCode != http.StatusOK { | ||||||
|  | 		return nil, fmt.Errorf("status code: %d", res.StatusCode) | ||||||
|  | 	} | ||||||
| 	body, err := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|   | |||||||
| @@ -16,6 +16,14 @@ import ( | |||||||
|  |  | ||||||
| func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { | func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { | ||||||
| 	switch channel.Type { | 	switch channel.Type { | ||||||
|  | 	case common.ChannelTypePaLM: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeAnthropic: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeBaidu: | ||||||
|  | 		fallthrough | ||||||
|  | 	case common.ChannelTypeZhipu: | ||||||
|  | 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | ||||||
| 	case common.ChannelTypeAzure: | 	case common.ChannelTypeAzure: | ||||||
| 		request.Model = "gpt-35-turbo" | 		request.Model = "gpt-35-turbo" | ||||||
| 	default: | 	default: | ||||||
| @@ -45,8 +53,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr | |||||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) | 		req.Header.Set("Authorization", "Bearer "+channel.Key) | ||||||
| 	} | 	} | ||||||
| 	req.Header.Set("Content-Type", "application/json") | 	req.Header.Set("Content-Type", "application/json") | ||||||
| 	client := &http.Client{} | 	resp, err := httpClient.Do(req) | ||||||
| 	resp, err := client.Do(req) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err, nil | 		return err, nil | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -128,7 +128,8 @@ func SendPasswordResetEmail(c *gin.Context) { | |||||||
| 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | ||||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | ||||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, common.VerificationValidMinutes) | 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||||
|  | 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) | ||||||
| 	err := common.SendEmail(subject, email, content) | 	err := common.SendEmail(subject, email, content) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|   | |||||||
| @@ -288,6 +288,15 @@ func init() { | |||||||
| 			Root:       "ERNIE-Bot-turbo", | 			Root:       "ERNIE-Bot-turbo", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "Embedding-V1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "baidu", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "Embedding-V1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "PaLM-2", | 			Id:         "PaLM-2", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| @@ -324,6 +333,33 @@ func init() { | |||||||
| 			Root:       "chatglm_lite", | 			Root:       "chatglm_lite", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "qwen-v1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "ali", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "qwen-v1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "qwen-plus-v1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "ali", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "qwen-plus-v1", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "SparkDesk", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "xunfei", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "SparkDesk", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | 	openAIModelsMap = make(map[string]OpenAIModels) | ||||||
| 	for _, model := range openAIModels { | 	for _, model := range openAIModels { | ||||||
|   | |||||||
							
								
								
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,240 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||||
|  |  | ||||||
|  | type AliMessage struct { | ||||||
|  | 	User string `json:"user"` | ||||||
|  | 	Bot  string `json:"bot"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliInput struct { | ||||||
|  | 	Prompt  string       `json:"prompt"` | ||||||
|  | 	History []AliMessage `json:"history"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliParameters struct { | ||||||
|  | 	TopP         float64 `json:"top_p,omitempty"` | ||||||
|  | 	TopK         int     `json:"top_k,omitempty"` | ||||||
|  | 	Seed         uint64  `json:"seed,omitempty"` | ||||||
|  | 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliChatRequest struct { | ||||||
|  | 	Model      string        `json:"model"` | ||||||
|  | 	Input      AliInput      `json:"input"` | ||||||
|  | 	Parameters AliParameters `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliError struct { | ||||||
|  | 	Code      string `json:"code"` | ||||||
|  | 	Message   string `json:"message"` | ||||||
|  | 	RequestId string `json:"request_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliUsage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliOutput struct { | ||||||
|  | 	Text         string `json:"text"` | ||||||
|  | 	FinishReason string `json:"finish_reason"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliChatResponse struct { | ||||||
|  | 	Output AliOutput `json:"output"` | ||||||
|  | 	Usage  AliUsage  `json:"usage"` | ||||||
|  | 	AliError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||||
|  | 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||||
|  | 	prompt := "" | ||||||
|  | 	for i := 0; i < len(request.Messages); i++ { | ||||||
|  | 		message := request.Messages[i] | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, AliMessage{ | ||||||
|  | 				User: message.Content, | ||||||
|  | 				Bot:  "Okay", | ||||||
|  | 			}) | ||||||
|  | 			continue | ||||||
|  | 		} else { | ||||||
|  | 			if i == len(request.Messages)-1 { | ||||||
|  | 				prompt = message.Content | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			messages = append(messages, AliMessage{ | ||||||
|  | 				User: message.Content, | ||||||
|  | 				Bot:  request.Messages[i+1].Content, | ||||||
|  | 			}) | ||||||
|  | 			i++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return &AliChatRequest{ | ||||||
|  | 		Model: request.Model, | ||||||
|  | 		Input: AliInput{ | ||||||
|  | 			Prompt:  prompt, | ||||||
|  | 			History: messages, | ||||||
|  | 		}, | ||||||
|  | 		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's | ||||||
|  | 		//	TopP: request.TopP, | ||||||
|  | 		//	TopK: 50, | ||||||
|  | 		//	//Seed:         0, | ||||||
|  | 		//	//EnableSearch: false, | ||||||
|  | 		//}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: response.Output.Text, | ||||||
|  | 		}, | ||||||
|  | 		FinishReason: response.Output.FinishReason, | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Id:      response.RequestId, | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 		Usage: Usage{ | ||||||
|  | 			PromptTokens:     response.Usage.InputTokens, | ||||||
|  | 			CompletionTokens: response.Usage.OutputTokens, | ||||||
|  | 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = aliResponse.Output.Text | ||||||
|  | 	choice.FinishReason = aliResponse.Output.FinishReason | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      aliResponse.RequestId, | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "ernie-bot", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var usage Usage | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 5 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:5] != "data:" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[5:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||||
|  | 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||||
|  | 	c.Writer.Header().Set("Connection", "keep-alive") | ||||||
|  | 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||||
|  | 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||||
|  | 	lastResponseText := "" | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var aliResponse AliChatResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			usage.PromptTokens += aliResponse.Usage.InputTokens | ||||||
|  | 			usage.CompletionTokens += aliResponse.Usage.OutputTokens | ||||||
|  | 			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||||
|  | 			response := streamResponseAli2OpenAI(&aliResponse) | ||||||
|  | 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||||
|  | 			lastResponseText = aliResponse.Output.Text | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var aliResponse AliChatResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &aliResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if aliResponse.Code != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: aliResponse.Message, | ||||||
|  | 				Type:    aliResponse.Code, | ||||||
|  | 				Param:   aliResponse.RequestId, | ||||||
|  | 				Code:    aliResponse.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
| @@ -54,14 +54,44 @@ type BaiduChatStreamResponse struct { | |||||||
| 	IsEnd      bool `json:"is_end"` | 	IsEnd      bool `json:"is_end"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type BaiduEmbeddingRequest struct { | ||||||
|  | 	Input []string `json:"input"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduEmbeddingData struct { | ||||||
|  | 	Object    string    `json:"object"` | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | 	Index     int       `json:"index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type BaiduEmbeddingResponse struct { | ||||||
|  | 	Id      string               `json:"id"` | ||||||
|  | 	Object  string               `json:"object"` | ||||||
|  | 	Created int64                `json:"created"` | ||||||
|  | 	Data    []BaiduEmbeddingData `json:"data"` | ||||||
|  | 	Usage   Usage                `json:"usage"` | ||||||
|  | 	BaiduError | ||||||
|  | } | ||||||
|  |  | ||||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, BaiduMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, BaiduMessage{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 		} else { | ||||||
| 			messages = append(messages, BaiduMessage{ | 			messages = append(messages, BaiduMessage{ | ||||||
| 				Role:    message.Role, | 				Role:    message.Role, | ||||||
| 				Content: message.Content, | 				Content: message.Content, | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
| 	return &BaiduChatRequest{ | 	return &BaiduChatRequest{ | ||||||
| 		Messages: messages, | 		Messages: messages, | ||||||
| 		Stream:   request.Stream, | 		Stream:   request.Stream, | ||||||
| @@ -101,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom | |||||||
| 	return &response | 	return &response | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||||
|  | 	baiduEmbeddingRequest := BaiduEmbeddingRequest{ | ||||||
|  | 		Input: nil, | ||||||
|  | 	} | ||||||
|  | 	switch request.Input.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		baiduEmbeddingRequest.Input = []string{request.Input.(string)} | ||||||
|  | 	case []string: | ||||||
|  | 		baiduEmbeddingRequest.Input = request.Input.([]string) | ||||||
|  | 	} | ||||||
|  | 	return &baiduEmbeddingRequest | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||||
|  | 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||||
|  | 		Object: "list", | ||||||
|  | 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | ||||||
|  | 		Model:  "baidu-embedding", | ||||||
|  | 		Usage:  response.Usage, | ||||||
|  | 	} | ||||||
|  | 	for _, item := range response.Data { | ||||||
|  | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||||
|  | 			Object:    item.Object, | ||||||
|  | 			Index:     item.Index, | ||||||
|  | 			Embedding: item.Embedding, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return &openAIEmbeddingResponse | ||||||
|  | } | ||||||
|  |  | ||||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
| 	var usage Usage | 	var usage Usage | ||||||
| 	scanner := bufio.NewScanner(resp.Body) | 	scanner := bufio.NewScanner(resp.Body) | ||||||
| @@ -201,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | |||||||
| 	_, err = c.Writer.Write(jsonResponse) | 	_, err = c.Writer.Write(jsonResponse) | ||||||
| 	return nil, &fullTextResponse.Usage | 	return nil, &fullTextResponse.Usage | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var baiduResponse BaiduEmbeddingResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if baiduResponse.ErrorMsg != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: baiduResponse.ErrorMsg, | ||||||
|  | 				Type:    "baidu_error", | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    baiduResponse.ErrorCode, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|   | |||||||
| @@ -69,11 +69,11 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | |||||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | ||||||
| 		} else if message.Role == "assistant" { | 		} else if message.Role == "assistant" { | ||||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||||
| 		} else { | 		} else if message.Role == "system" { | ||||||
| 			// ignore other roles | 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	prompt += "\n\nAssistant:" | 	prompt += "\n\nAssistant:" | ||||||
| 	} |  | ||||||
| 	claudeRequest.Prompt = prompt | 	claudeRequest.Prompt = prompt | ||||||
| 	return &claudeRequest | 	return &claudeRequest | ||||||
| } | } | ||||||
|   | |||||||
| @@ -109,8 +109,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
|  |  | ||||||
| 	client := &http.Client{} | 	resp, err := httpClient.Do(req) | ||||||
| 	resp, err := client.Do(req) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -115,7 +115,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope | |||||||
| 	} | 	} | ||||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||||
| 	// So the client will be confused by the response. | 	// So the httpClient will be confused by the response. | ||||||
| 	// For example, Postman will report error, and we cannot check the response at all. | 	// For example, Postman will report error, and we cannot check the response at all. | ||||||
| 	for k, v := range resp.Header { | 	for k, v := range resp.Header { | ||||||
| 		c.Writer.Header().Set(k, v[0]) | 		c.Writer.Header().Set(k, v[0]) | ||||||
|   | |||||||
| @@ -20,8 +20,16 @@ const ( | |||||||
| 	APITypePaLM | 	APITypePaLM | ||||||
| 	APITypeBaidu | 	APITypeBaidu | ||||||
| 	APITypeZhipu | 	APITypeZhipu | ||||||
|  | 	APITypeAli | ||||||
|  | 	APITypeXunfei | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var httpClient *http.Client | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	httpClient = &http.Client{} | ||||||
|  | } | ||||||
|  |  | ||||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| @@ -67,7 +75,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	// map model name | 	// map model name | ||||||
| 	modelMapping := c.GetString("model_mapping") | 	modelMapping := c.GetString("model_mapping") | ||||||
| 	isModelMapped := false | 	isModelMapped := false | ||||||
| 	if modelMapping != "" { | 	if modelMapping != "" && modelMapping != "{}" { | ||||||
| 		modelMap := make(map[string]string) | 		modelMap := make(map[string]string) | ||||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -79,14 +87,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	apiType := APITypeOpenAI | 	apiType := APITypeOpenAI | ||||||
| 	if strings.HasPrefix(textRequest.Model, "claude") { | 	switch channelType { | ||||||
|  | 	case common.ChannelTypeAnthropic: | ||||||
| 		apiType = APITypeClaude | 		apiType = APITypeClaude | ||||||
| 	} else if strings.HasPrefix(textRequest.Model, "ERNIE") { | 	case common.ChannelTypeBaidu: | ||||||
| 		apiType = APITypeBaidu | 		apiType = APITypeBaidu | ||||||
| 	} else if strings.HasPrefix(textRequest.Model, "PaLM") { | 	case common.ChannelTypePaLM: | ||||||
| 		apiType = APITypePaLM | 		apiType = APITypePaLM | ||||||
| 	} else if strings.HasPrefix(textRequest.Model, "chatglm_") { | 	case common.ChannelTypeZhipu: | ||||||
| 		apiType = APITypeZhipu | 		apiType = APITypeZhipu | ||||||
|  | 	case common.ChannelTypeAli: | ||||||
|  | 		apiType = APITypeAli | ||||||
|  | 	case common.ChannelTypeXunfei: | ||||||
|  | 		apiType = APITypeXunfei | ||||||
| 	} | 	} | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
| @@ -128,12 +141,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||||
| 		case "BLOOMZ-7B": | 		case "BLOOMZ-7B": | ||||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||||
|  | 		case "Embedding-V1": | ||||||
|  | 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||||
| 		} | 		} | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
| 		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days | 		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days | ||||||
| 	case APITypePaLM: | 	case APITypePaLM: | ||||||
| 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" | 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" | ||||||
|  | 		if baseURL != "" { | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) | ||||||
|  | 		} | ||||||
| 		apiKey := c.Request.Header.Get("Authorization") | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
| 		fullRequestURL += "?key=" + apiKey | 		fullRequestURL += "?key=" + apiKey | ||||||
| @@ -143,6 +161,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			method = "sse-invoke" | 			method = "sse-invoke" | ||||||
| 		} | 		} | ||||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||||
|  | 	case APITypeAli: | ||||||
|  | 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||||
| 	} | 	} | ||||||
| 	var promptTokens int | 	var promptTokens int | ||||||
| 	var completionTokens int | 	var completionTokens int | ||||||
| @@ -196,12 +216,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeBaidu: | 	case APITypeBaidu: | ||||||
|  | 		var jsonData []byte | ||||||
|  | 		var err error | ||||||
|  | 		switch relayMode { | ||||||
|  | 		case RelayModeEmbeddings: | ||||||
|  | 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | ||||||
|  | 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||||
|  | 		default: | ||||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) | 			baiduRequest := requestOpenAI2Baidu(textRequest) | ||||||
| 		jsonStr, err := json.Marshal(baiduRequest) | 			jsonData, err = json.Marshal(baiduRequest) | ||||||
|  | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonData) | ||||||
| 	case APITypePaLM: | 	case APITypePaLM: | ||||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) | 		palmRequest := requestOpenAI2PaLM(textRequest) | ||||||
| 		jsonStr, err := json.Marshal(palmRequest) | 		jsonStr, err := json.Marshal(palmRequest) | ||||||
| @@ -216,8 +244,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeAli: | ||||||
|  | 		aliRequest := requestOpenAI2Ali(textRequest) | ||||||
|  | 		jsonStr, err := json.Marshal(aliRequest) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var req *http.Request | ||||||
|  | 	var resp *http.Response | ||||||
|  | 	isStream := textRequest.Stream | ||||||
|  |  | ||||||
|  | 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||||
|  | 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| @@ -240,12 +281,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		case APITypeZhipu: | 		case APITypeZhipu: | ||||||
| 			token := getZhipuToken(apiKey) | 			token := getZhipuToken(apiKey) | ||||||
| 			req.Header.Set("Authorization", token) | 			req.Header.Set("Authorization", token) | ||||||
|  | 		case APITypeAli: | ||||||
|  | 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||||
|  | 			if textRequest.Stream { | ||||||
|  | 				req.Header.Set("X-DashScope-SSE", "enable") | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||||
| 	client := &http.Client{} | 		resp, err = httpClient.Do(req) | ||||||
| 	resp, err := client.Do(req) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| @@ -257,9 +302,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
|  | 		isStream = strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	var textResponse TextResponse | 	var textResponse TextResponse | ||||||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") |  | ||||||
| 	var streamResponseText string |  | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		if consumeQuota { | 		if consumeQuota { | ||||||
| @@ -271,16 +317,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if strings.HasPrefix(textRequest.Model, "gpt-4") { | 			if strings.HasPrefix(textRequest.Model, "gpt-4") { | ||||||
| 				completionRatio = 2 | 				completionRatio = 2 | ||||||
| 			} | 			} | ||||||
| 			if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu { |  | ||||||
| 				completionTokens = countTokenText(streamResponseText, textRequest.Model) |  | ||||||
| 			} else { |  | ||||||
| 			promptTokens = textResponse.Usage.PromptTokens | 			promptTokens = textResponse.Usage.PromptTokens | ||||||
| 			completionTokens = textResponse.Usage.CompletionTokens | 			completionTokens = textResponse.Usage.CompletionTokens | ||||||
| 				if apiType == APITypeZhipu { |  | ||||||
| 					// zhipu's API does not return prompt tokens & completion tokens |  | ||||||
| 					promptTokens = textResponse.Usage.TotalTokens |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			quota = promptTokens + int(float64(completionTokens)*completionRatio) | 			quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||||
| 			quota = int(float64(quota) * ratio) | 			quota = int(float64(quota) * ratio) | ||||||
| 			if ratio != 0 && quota <= 0 { | 			if ratio != 0 && quota <= 0 { | ||||||
| @@ -318,7 +358,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			streamResponseText = responseText | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
|  | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := openaiHandler(c, resp, consumeQuota) | 			err, usage := openaiHandler(c, resp, consumeQuota) | ||||||
| @@ -336,7 +377,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			streamResponseText = responseText | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
|  | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | ||||||
| @@ -359,7 +401,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := baiduHandler(c, resp) | 			var err *OpenAIErrorWithStatusCode | ||||||
|  | 			var usage *Usage | ||||||
|  | 			switch relayMode { | ||||||
|  | 			case RelayModeEmbeddings: | ||||||
|  | 				err, usage = baiduEmbeddingHandler(c, resp) | ||||||
|  | 			default: | ||||||
|  | 				err, usage = baiduHandler(c, resp) | ||||||
|  | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -374,7 +423,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			streamResponseText = responseText | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
|  | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | ||||||
| @@ -395,6 +445,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if usage != nil { | 			if usage != nil { | ||||||
| 				textResponse.Usage = *usage | 				textResponse.Usage = *usage | ||||||
| 			} | 			} | ||||||
|  | 			// zhipu's API does not return prompt tokens & completion tokens | ||||||
|  | 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := zhipuHandler(c, resp) | 			err, usage := zhipuHandler(c, resp) | ||||||
| @@ -404,8 +456,49 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if usage != nil { | 			if usage != nil { | ||||||
| 				textResponse.Usage = *usage | 				textResponse.Usage = *usage | ||||||
| 			} | 			} | ||||||
|  | 			// zhipu's API does not return prompt tokens & completion tokens | ||||||
|  | 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
|  | 	case APITypeAli: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage := aliStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := aliHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	case APITypeXunfei: | ||||||
|  | 		if isStream { | ||||||
|  | 			auth := c.Request.Header.Get("Authorization") | ||||||
|  | 			auth = strings.TrimPrefix(auth, "Bearer ") | ||||||
|  | 			splits := strings.Split(auth, "|") | ||||||
|  | 			if len(splits) != 3 { | ||||||
|  | 				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||||
|  | 			} | ||||||
|  | 			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) | ||||||
|  | 		} | ||||||
| 	default: | 	default: | ||||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										274
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										274
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,274 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/hmac" | ||||||
|  | 	"crypto/sha256" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/gorilla/websocket" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://console.xfyun.cn/services/cbm | ||||||
|  | // https://www.xfyun.cn/doc/spark/Web.html | ||||||
|  |  | ||||||
|  | type XunfeiMessage struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type XunfeiChatRequest struct { | ||||||
|  | 	Header struct { | ||||||
|  | 		AppId string `json:"app_id"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Parameter struct { | ||||||
|  | 		Chat struct { | ||||||
|  | 			Domain      string  `json:"domain,omitempty"` | ||||||
|  | 			Temperature float64 `json:"temperature,omitempty"` | ||||||
|  | 			TopK        int     `json:"top_k,omitempty"` | ||||||
|  | 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||||
|  | 			Auditing    bool    `json:"auditing,omitempty"` | ||||||
|  | 		} `json:"chat"` | ||||||
|  | 	} `json:"parameter"` | ||||||
|  | 	Payload struct { | ||||||
|  | 		Message struct { | ||||||
|  | 			Text []XunfeiMessage `json:"text"` | ||||||
|  | 		} `json:"message"` | ||||||
|  | 	} `json:"payload"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type XunfeiChatResponseTextItem struct { | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Index   int    `json:"index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type XunfeiChatResponse struct { | ||||||
|  | 	Header struct { | ||||||
|  | 		Code    int    `json:"code"` | ||||||
|  | 		Message string `json:"message"` | ||||||
|  | 		Sid     string `json:"sid"` | ||||||
|  | 		Status  int    `json:"status"` | ||||||
|  | 	} `json:"header"` | ||||||
|  | 	Payload struct { | ||||||
|  | 		Choices struct { | ||||||
|  | 			Status int                          `json:"status"` | ||||||
|  | 			Seq    int                          `json:"seq"` | ||||||
|  | 			Text   []XunfeiChatResponseTextItem `json:"text"` | ||||||
|  | 		} `json:"choices"` | ||||||
|  | 	} `json:"payload"` | ||||||
|  | 	Usage struct { | ||||||
|  | 		//Text struct { | ||||||
|  | 		//	QuestionTokens   string `json:"question_tokens"` | ||||||
|  | 		//	PromptTokens     string `json:"prompt_tokens"` | ||||||
|  | 		//	CompletionTokens string `json:"completion_tokens"` | ||||||
|  | 		//	TotalTokens      string `json:"total_tokens"` | ||||||
|  | 		//} `json:"text"` | ||||||
|  | 		Text Usage `json:"text"` | ||||||
|  | 	} `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest { | ||||||
|  | 	messages := make([]XunfeiMessage, 0, len(request.Messages)) | ||||||
|  | 	for _, message := range request.Messages { | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, XunfeiMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, XunfeiMessage{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 		} else { | ||||||
|  | 			messages = append(messages, XunfeiMessage{ | ||||||
|  | 				Role:    message.Role, | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	xunfeiRequest := XunfeiChatRequest{} | ||||||
|  | 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||||
|  | 	xunfeiRequest.Parameter.Chat.Domain = "general" | ||||||
|  | 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||||
|  | 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||||
|  | 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||||
|  | 	xunfeiRequest.Payload.Message.Text = messages | ||||||
|  | 	return &xunfeiRequest | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | ||||||
|  | 	if len(response.Payload.Choices.Text) == 0 { | ||||||
|  | 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||||
|  | 			{ | ||||||
|  | 				Content: "", | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: response.Payload.Choices.Text[0].Content, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 		Usage:   response.Usage.Text, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||||
|  | 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||||
|  | 			{ | ||||||
|  | 				Content: "", | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "SparkDesk", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||||
|  | 	HmacWithShaToBase64 := func(algorithm, data, key string) string { | ||||||
|  | 		mac := hmac.New(sha256.New, []byte(key)) | ||||||
|  | 		mac.Write([]byte(data)) | ||||||
|  | 		encodeData := mac.Sum(nil) | ||||||
|  | 		return base64.StdEncoding.EncodeToString(encodeData) | ||||||
|  | 	} | ||||||
|  | 	ul, err := url.Parse(hostUrl) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Println(err) | ||||||
|  | 	} | ||||||
|  | 	date := time.Now().UTC().Format(time.RFC1123) | ||||||
|  | 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||||
|  | 	sign := strings.Join(signString, "\n") | ||||||
|  | 	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) | ||||||
|  | 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||||
|  | 		"hmac-sha256", "host date request-line", sha) | ||||||
|  | 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||||
|  | 	v := url.Values{} | ||||||
|  | 	v.Add("host", ul.Host) | ||||||
|  | 	v.Add("date", date) | ||||||
|  | 	v.Add("authorization", authorization) | ||||||
|  | 	callUrl := hostUrl + "?" + v.Encode() | ||||||
|  | 	return callUrl | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var usage Usage | ||||||
|  | 	d := websocket.Dialer{ | ||||||
|  | 		HandshakeTimeout: 5 * time.Second, | ||||||
|  | 	} | ||||||
|  | 	hostUrl := "wss://aichat.xf-yun.com/v1/chat" | ||||||
|  | 	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) | ||||||
|  | 	if err != nil || resp.StatusCode != 101 { | ||||||
|  | 		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	data := requestOpenAI2Xunfei(textRequest, appId) | ||||||
|  | 	err = conn.WriteJSON(data) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	dataChan := make(chan XunfeiChatResponse) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			_, msg, err := conn.ReadMessage() | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error reading stream response: " + err.Error()) | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			var response XunfeiChatResponse | ||||||
|  | 			err = json.Unmarshal(msg, &response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			dataChan <- response | ||||||
|  | 			if response.Payload.Choices.Status == 2 { | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||||
|  | 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||||
|  | 	c.Writer.Header().Set("Connection", "keep-alive") | ||||||
|  | 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||||
|  | 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case xunfeiResponse := <-dataChan: | ||||||
|  | 			usage.PromptTokens += xunfeiResponse.Usage.Text.PromptTokens | ||||||
|  | 			usage.CompletionTokens += xunfeiResponse.Usage.Text.CompletionTokens | ||||||
|  | 			usage.TotalTokens += xunfeiResponse.Usage.Text.TotalTokens | ||||||
|  | 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var xunfeiResponse XunfeiChatResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &xunfeiResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if xunfeiResponse.Header.Code != 0 { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: xunfeiResponse.Header.Message, | ||||||
|  | 				Type:    "xunfei_error", | ||||||
|  | 				Param:   "", | ||||||
|  | 				Code:    xunfeiResponse.Header.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
| @@ -111,11 +111,22 @@ func getZhipuToken(apikey string) string { | |||||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | ||||||
| 	for _, message := range request.Messages { | 	for _, message := range request.Messages { | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, ZhipuMessage{ | ||||||
|  | 				Role:    "system", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, ZhipuMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 		} else { | ||||||
| 			messages = append(messages, ZhipuMessage{ | 			messages = append(messages, ZhipuMessage{ | ||||||
| 				Role:    message.Role, | 				Role:    message.Role, | ||||||
| 				Content: message.Content, | 				Content: message.Content, | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
| 	return &ZhipuRequest{ | 	return &ZhipuRequest{ | ||||||
| 		Prompt:      messages, | 		Prompt:      messages, | ||||||
| 		Temperature: request.Temperature, | 		Temperature: request.Temperature, | ||||||
|   | |||||||
| @@ -99,6 +99,19 @@ type OpenAITextResponse struct { | |||||||
| 	Usage   `json:"usage"` | 	Usage   `json:"usage"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type OpenAIEmbeddingResponseItem struct { | ||||||
|  | 	Object    string    `json:"object"` | ||||||
|  | 	Index     int       `json:"index"` | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type OpenAIEmbeddingResponse struct { | ||||||
|  | 	Object string                        `json:"object"` | ||||||
|  | 	Data   []OpenAIEmbeddingResponseItem `json:"data"` | ||||||
|  | 	Model  string                        `json:"model"` | ||||||
|  | 	Usage  `json:"usage"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type ImageResponse struct { | type ImageResponse struct { | ||||||
| 	Created int `json:"created"` | 	Created int `json:"created"` | ||||||
| 	Data    []struct { | 	Data    []struct { | ||||||
|   | |||||||
| @@ -3,12 +3,13 @@ package controller | |||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type LoginRequest struct { | type LoginRequest struct { | ||||||
| @@ -477,6 +478,16 @@ func DeleteUser(c *gin.Context) { | |||||||
|  |  | ||||||
| func DeleteSelf(c *gin.Context) { | func DeleteSelf(c *gin.Context) { | ||||||
| 	id := c.GetInt("id") | 	id := c.GetInt("id") | ||||||
|  | 	user, _ := model.GetUserById(id, false) | ||||||
|  |  | ||||||
|  | 	if user.Role == common.RoleRootUser { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "不能删除超级管理员账户", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	err := model.DeleteUserById(id) | 	err := model.DeleteUserById(id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(http.StatusOK, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @@ -13,6 +13,7 @@ require ( | |||||||
| 	github.com/go-redis/redis/v8 v8.11.5 | 	github.com/go-redis/redis/v8 v8.11.5 | ||||||
| 	github.com/golang-jwt/jwt v3.2.2+incompatible | 	github.com/golang-jwt/jwt v3.2.2+incompatible | ||||||
| 	github.com/google/uuid v1.3.0 | 	github.com/google/uuid v1.3.0 | ||||||
|  | 	github.com/gorilla/websocket v1.5.0 | ||||||
| 	github.com/pkoukk/tiktoken-go v0.1.1 | 	github.com/pkoukk/tiktoken-go v0.1.1 | ||||||
| 	golang.org/x/crypto v0.9.0 | 	golang.org/x/crypto v0.9.0 | ||||||
| 	gorm.io/driver/mysql v1.4.3 | 	gorm.io/driver/mysql v1.4.3 | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @@ -67,6 +67,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC | |||||||
| github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | ||||||
| github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | ||||||
| github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | ||||||
|  | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||||
|  | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||||
|   | |||||||
| @@ -503,5 +503,12 @@ | |||||||
|   "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", |   "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", | ||||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", |   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||||
|   "Homepage URL 填": "Fill in the Homepage URL", |   "Homepage URL 填": "Fill in the Homepage URL", | ||||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL" |   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||||
|  |   "请为通道命名": "Please name the channel", | ||||||
|  |   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||||
|  |   "模型重定向": "Model redirection", | ||||||
|  |   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||||
|  |   "注意,": "Note that, ", | ||||||
|  |   ",图片演示。": "related image demo.", | ||||||
|  |   "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!" | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								main.go
									
									
									
									
									
								
							| @@ -54,6 +54,7 @@ func main() { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) | 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) | ||||||
| 		} | 		} | ||||||
|  | 		common.SyncFrequency = frequency | ||||||
| 		go model.SyncOptions(frequency) | 		go model.SyncOptions(frequency) | ||||||
| 		if common.RedisEnabled { | 		if common.RedisEnabled { | ||||||
| 			go model.SyncChannelCache(frequency) | 			go model.SyncChannelCache(frequency) | ||||||
|   | |||||||
| @@ -12,11 +12,11 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | var ( | ||||||
| 	TokenCacheSeconds         = 60 * 60 | 	TokenCacheSeconds         = common.SyncFrequency | ||||||
| 	UserId2GroupCacheSeconds  = 60 * 60 | 	UserId2GroupCacheSeconds  = common.SyncFrequency | ||||||
| 	UserId2QuotaCacheSeconds  = 10 * 60 | 	UserId2QuotaCacheSeconds  = common.SyncFrequency | ||||||
| 	UserId2StatusCacheSeconds = 60 * 60 | 	UserId2StatusCacheSeconds = common.SyncFrequency | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func CacheGetTokenByKey(key string) (*Token, error) { | func CacheGetTokenByKey(key string) (*Token, error) { | ||||||
| @@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second) | 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set token error: " + err.Error()) | 			common.SysError("Redis set token error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| @@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return "", err | ||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set user group error: " + err.Error()) | 			common.SysError("Redis set user group error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| @@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return 0, err | 			return 0, err | ||||||
| 		} | 		} | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set user quota error: " + err.Error()) | 			common.SysError("Redis set user quota error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
| @@ -91,7 +91,7 @@ func CacheUpdateUserQuota(id int) error { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) | 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -106,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool { | |||||||
| 			status = common.UserStatusEnabled | 			status = common.UserStatusEnabled | ||||||
| 		} | 		} | ||||||
| 		enabled = fmt.Sprintf("%d", status) | 		enabled = fmt.Sprintf("%d", status) | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second) | 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.SysError("Redis set user enabled error: " + err.Error()) | 			common.SysError("Redis set user enabled error: " + err.Error()) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -51,20 +51,21 @@ func Redeem(key string, userId int) (quota int, err error) { | |||||||
| 	redemption := &Redemption{} | 	redemption := &Redemption{} | ||||||
|  |  | ||||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||||
| 		err := DB.Where("`key` = ?", key).First(redemption).Error | 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errors.New("无效的兑换码") | 			return errors.New("无效的兑换码") | ||||||
| 		} | 		} | ||||||
| 		if redemption.Status != common.RedemptionCodeStatusEnabled { | 		if redemption.Status != common.RedemptionCodeStatusEnabled { | ||||||
| 			return errors.New("该兑换码已被使用") | 			return errors.New("该兑换码已被使用") | ||||||
| 		} | 		} | ||||||
| 		err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		redemption.RedeemedTime = common.GetTimestamp() | 		redemption.RedeemedTime = common.GetTimestamp() | ||||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | 		redemption.Status = common.RedemptionCodeStatusUsed | ||||||
| 		return redemption.SelectUpdate() | 		err = tx.Save(redemption).Error | ||||||
|  | 		return err | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, errors.New("兑换失败," + err.Error()) | 		return 0, errors.New("兑换失败," + err.Error()) | ||||||
|   | |||||||
| @@ -36,7 +36,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 			{ | 			{ | ||||||
| 				selfRoute.GET("/self", controller.GetSelf) | 				selfRoute.GET("/self", controller.GetSelf) | ||||||
| 				selfRoute.PUT("/self", controller.UpdateSelf) | 				selfRoute.PUT("/self", controller.UpdateSelf) | ||||||
| 				selfRoute.DELETE("/self", controller.DeleteSelf) | 				selfRoute.DELETE("/self", middleware.TurnstileCheck(), controller.DeleteSelf) | ||||||
| 				selfRoute.GET("/token", controller.GenerateAccessToken) | 				selfRoute.GET("/token", controller.GenerateAccessToken) | ||||||
| 				selfRoute.GET("/aff", controller.GetAffCode) | 				selfRoute.GET("/aff", controller.GetAffCode) | ||||||
| 				selfRoute.POST("/topup", controller.TopUp) | 				selfRoute.POST("/topup", controller.TopUp) | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) { | |||||||
| 	modelsRouter := router.Group("/v1/models") | 	modelsRouter := router.Group("/v1/models") | ||||||
| 	modelsRouter.Use(middleware.TokenAuth()) | 	modelsRouter.Use(middleware.TokenAuth()) | ||||||
| 	{ | 	{ | ||||||
| 		modelsRouter.GET("/", controller.ListModels) | 		modelsRouter.GET("", controller.ListModels) | ||||||
| 		modelsRouter.GET("/:model", controller.RetrieveModel) | 		modelsRouter.GET("/:model", controller.RetrieveModel) | ||||||
| 	} | 	} | ||||||
| 	relayV1Router := router.Group("/v1") | 	relayV1Router := router.Group("/v1") | ||||||
|   | |||||||
| @@ -363,9 +363,12 @@ const ChannelsTable = () => { | |||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
|                   <Table.Cell> |                   <Table.Cell> | ||||||
|                     <Popup |                     <Popup | ||||||
|                       content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'} |                       trigger={<span onClick={() => { | ||||||
|                       key={channel.id} |                         updateChannelBalance(channel.id, channel.name, idx); | ||||||
|                       trigger={renderBalance(channel.type, channel.balance)} |                       }} style={{ cursor: 'pointer' }}> | ||||||
|  |                       {renderBalance(channel.type, channel.balance)} | ||||||
|  |                     </span>} | ||||||
|  |                       content="点击更新" | ||||||
|                       basic |                       basic | ||||||
|                     /> |                     /> | ||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
| @@ -380,16 +383,16 @@ const ChannelsTable = () => { | |||||||
|                       > |                       > | ||||||
|                         测试 |                         测试 | ||||||
|                       </Button> |                       </Button> | ||||||
|                       <Button |                       {/*<Button*/} | ||||||
|                         size={'small'} |                       {/*  size={'small'}*/} | ||||||
|                         positive |                       {/*  positive*/} | ||||||
|                         loading={updatingBalance} |                       {/*  loading={updatingBalance}*/} | ||||||
|                         onClick={() => { |                       {/*  onClick={() => {*/} | ||||||
|                           updateChannelBalance(channel.id, channel.name, idx); |                       {/*    updateChannelBalance(channel.id, channel.name, idx);*/} | ||||||
|                         }} |                       {/*  }}*/} | ||||||
|                       > |                       {/*>*/} | ||||||
|                         更新余额 |                       {/*  更新余额*/} | ||||||
|                       </Button> |                       {/*</Button>*/} | ||||||
|                       <Popup |                       <Popup | ||||||
|                         trigger={ |                         trigger={ | ||||||
|                           <Button size='small' negative> |                           <Button size='small' negative> | ||||||
|   | |||||||
| @@ -1,36 +1,25 @@ | |||||||
| import React, { useContext, useEffect, useState } from 'react'; | import React, { useContext, useEffect, useState } from 'react'; | ||||||
| import { | import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||||
|   Button, |  | ||||||
|   Divider, |  | ||||||
|   Form, |  | ||||||
|   Grid, |  | ||||||
|   Header, |  | ||||||
|   Image, |  | ||||||
|   Message, |  | ||||||
|   Modal, |  | ||||||
|   Segment, |  | ||||||
| } from 'semantic-ui-react'; |  | ||||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||||
| import { UserContext } from '../context/User'; | import { UserContext } from '../context/User'; | ||||||
| import { API, getLogo, showError, showSuccess, showInfo } from '../helpers'; | import { API, getLogo, showError, showSuccess } from '../helpers'; | ||||||
|  |  | ||||||
| const LoginForm = () => { | const LoginForm = () => { | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
|     username: '', |     username: '', | ||||||
|     password: '', |     password: '', | ||||||
|     wechat_verification_code: '', |     wechat_verification_code: '' | ||||||
|   }); |   }); | ||||||
|   const [searchParams, setSearchParams] = useSearchParams(); |   const [searchParams, setSearchParams] = useSearchParams(); | ||||||
|   const [submitted, setSubmitted] = useState(false); |   const [submitted, setSubmitted] = useState(false); | ||||||
|   const { username, password } = inputs; |   const { username, password } = inputs; | ||||||
|   const [userState, userDispatch] = useContext(UserContext); |   const [userState, userDispatch] = useContext(UserContext); | ||||||
|   let navigate = useNavigate(); |   let navigate = useNavigate(); | ||||||
|  |  | ||||||
|   const [status, setStatus] = useState({}); |   const [status, setStatus] = useState({}); | ||||||
|   const logo = getLogo(); |   const logo = getLogo(); | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     if (searchParams.get("expired")) { |     if (searchParams.get('expired')) { | ||||||
|       showError('未登录或登录已过期,请重新登录!'); |       showError('未登录或登录已过期,请重新登录!'); | ||||||
|     } |     } | ||||||
|     let status = localStorage.getItem('status'); |     let status = localStorage.getItem('status'); | ||||||
| @@ -78,7 +67,7 @@ const LoginForm = () => { | |||||||
|     if (username && password) { |     if (username && password) { | ||||||
|       const res = await API.post(`/api/user/login`, { |       const res = await API.post(`/api/user/login`, { | ||||||
|         username, |         username, | ||||||
|         password, |         password | ||||||
|       }); |       }); | ||||||
|       const { success, message, data } = res.data; |       const { success, message, data } = res.data; | ||||||
|       if (success) { |       if (success) { | ||||||
| @@ -93,44 +82,44 @@ const LoginForm = () => { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Grid textAlign="center" style={{ marginTop: '48px' }}> |     <Grid textAlign='center' style={{ marginTop: '48px' }}> | ||||||
|       <Grid.Column style={{ maxWidth: 450 }}> |       <Grid.Column style={{ maxWidth: 450 }}> | ||||||
|         <Header as="h2" color="" textAlign="center"> |         <Header as='h2' color='' textAlign='center'> | ||||||
|           <Image src={logo} /> 用户登录 |           <Image src={logo} /> 用户登录 | ||||||
|         </Header> |         </Header> | ||||||
|         <Form size="large"> |         <Form size='large'> | ||||||
|           <Segment> |           <Segment> | ||||||
|             <Form.Input |             <Form.Input | ||||||
|               fluid |               fluid | ||||||
|               icon="user" |               icon='user' | ||||||
|               iconPosition="left" |               iconPosition='left' | ||||||
|               placeholder="用户名" |               placeholder='用户名' | ||||||
|               name="username" |               name='username' | ||||||
|               value={username} |               value={username} | ||||||
|               onChange={handleChange} |               onChange={handleChange} | ||||||
|             /> |             /> | ||||||
|             <Form.Input |             <Form.Input | ||||||
|               fluid |               fluid | ||||||
|               icon="lock" |               icon='lock' | ||||||
|               iconPosition="left" |               iconPosition='left' | ||||||
|               placeholder="密码" |               placeholder='密码' | ||||||
|               name="password" |               name='password' | ||||||
|               type="password" |               type='password' | ||||||
|               value={password} |               value={password} | ||||||
|               onChange={handleChange} |               onChange={handleChange} | ||||||
|             /> |             /> | ||||||
|             <Button color="" fluid size="large" onClick={handleSubmit}> |             <Button color='green' fluid size='large' onClick={handleSubmit}> | ||||||
|               登录 |               登录 | ||||||
|             </Button> |             </Button> | ||||||
|           </Segment> |           </Segment> | ||||||
|         </Form> |         </Form> | ||||||
|         <Message> |         <Message> | ||||||
|           忘记密码? |           忘记密码? | ||||||
|           <Link to="/reset" className="btn btn-link"> |           <Link to='/reset' className='btn btn-link'> | ||||||
|             点击重置 |             点击重置 | ||||||
|           </Link> |           </Link> | ||||||
|           ; 没有账户? |           ; 没有账户? | ||||||
|           <Link to="/register" className="btn btn-link"> |           <Link to='/register' className='btn btn-link'> | ||||||
|             点击注册 |             点击注册 | ||||||
|           </Link> |           </Link> | ||||||
|         </Message> |         </Message> | ||||||
| @@ -140,8 +129,8 @@ const LoginForm = () => { | |||||||
|             {status.github_oauth ? ( |             {status.github_oauth ? ( | ||||||
|               <Button |               <Button | ||||||
|                 circular |                 circular | ||||||
|                 color="black" |                 color='black' | ||||||
|                 icon="github" |                 icon='github' | ||||||
|                 onClick={onGitHubOAuthClicked} |                 onClick={onGitHubOAuthClicked} | ||||||
|               /> |               /> | ||||||
|             ) : ( |             ) : ( | ||||||
| @@ -150,8 +139,8 @@ const LoginForm = () => { | |||||||
|             {status.wechat_login ? ( |             {status.wechat_login ? ( | ||||||
|               <Button |               <Button | ||||||
|                 circular |                 circular | ||||||
|                 color="green" |                 color='green' | ||||||
|                 icon="wechat" |                 icon='wechat' | ||||||
|                 onClick={onWeChatLoginClicked} |                 onClick={onWeChatLoginClicked} | ||||||
|               /> |               /> | ||||||
|             ) : ( |             ) : ( | ||||||
| @@ -175,18 +164,18 @@ const LoginForm = () => { | |||||||
|                   微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) |                   微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) | ||||||
|                 </p> |                 </p> | ||||||
|               </div> |               </div> | ||||||
|               <Form size="large"> |               <Form size='large'> | ||||||
|                 <Form.Input |                 <Form.Input | ||||||
|                   fluid |                   fluid | ||||||
|                   placeholder="验证码" |                   placeholder='验证码' | ||||||
|                   name="wechat_verification_code" |                   name='wechat_verification_code' | ||||||
|                   value={inputs.wechat_verification_code} |                   value={inputs.wechat_verification_code} | ||||||
|                   onChange={handleChange} |                   onChange={handleChange} | ||||||
|                 /> |                 /> | ||||||
|                 <Button |                 <Button | ||||||
|                   color="" |                   color='' | ||||||
|                   fluid |                   fluid | ||||||
|                   size="large" |                   size='large' | ||||||
|                   onClick={onSubmitWeChatVerificationCode} |                   onClick={onSubmitWeChatVerificationCode} | ||||||
|                 > |                 > | ||||||
|                   登录 |                   登录 | ||||||
|   | |||||||
| @@ -12,6 +12,11 @@ const PasswordResetConfirm = () => { | |||||||
|  |  | ||||||
|   const [loading, setLoading] = useState(false); |   const [loading, setLoading] = useState(false); | ||||||
|  |  | ||||||
|  |   const [disableButton, setDisableButton] = useState(false); | ||||||
|  |   const [countdown, setCountdown] = useState(30); | ||||||
|  |  | ||||||
|  |   const [newPassword, setNewPassword] = useState(''); | ||||||
|  |  | ||||||
|   const [searchParams, setSearchParams] = useSearchParams(); |   const [searchParams, setSearchParams] = useSearchParams(); | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let token = searchParams.get('token'); |     let token = searchParams.get('token'); | ||||||
| @@ -22,7 +27,21 @@ const PasswordResetConfirm = () => { | |||||||
|     }); |     }); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|  |   useEffect(() => { | ||||||
|  |     let countdownInterval = null; | ||||||
|  |     if (disableButton && countdown > 0) { | ||||||
|  |       countdownInterval = setInterval(() => { | ||||||
|  |         setCountdown(countdown - 1); | ||||||
|  |       }, 1000); | ||||||
|  |     } else if (countdown === 0) { | ||||||
|  |       setDisableButton(false); | ||||||
|  |       setCountdown(30); | ||||||
|  |     } | ||||||
|  |     return () => clearInterval(countdownInterval);  | ||||||
|  |   }, [disableButton, countdown]); | ||||||
|  |  | ||||||
|   async function handleSubmit(e) { |   async function handleSubmit(e) { | ||||||
|  |     setDisableButton(true); | ||||||
|     if (!email) return; |     if (!email) return; | ||||||
|     setLoading(true); |     setLoading(true); | ||||||
|     const res = await API.post(`/api/user/reset`, { |     const res = await API.post(`/api/user/reset`, { | ||||||
| @@ -32,8 +51,9 @@ const PasswordResetConfirm = () => { | |||||||
|     const { success, message } = res.data; |     const { success, message } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       let password = res.data.data; |       let password = res.data.data; | ||||||
|  |       setNewPassword(password); | ||||||
|       await copy(password); |       await copy(password); | ||||||
|       showNotice(`密码已重置并已复制到剪贴板:${password}`); |       showNotice(`新密码已复制到剪贴板:${password}`); | ||||||
|     } else { |     } else { | ||||||
|       showError(message); |       showError(message); | ||||||
|     } |     } | ||||||
| @@ -57,14 +77,31 @@ const PasswordResetConfirm = () => { | |||||||
|               value={email} |               value={email} | ||||||
|               readOnly |               readOnly | ||||||
|             /> |             /> | ||||||
|  |             {newPassword && ( | ||||||
|  |               <Form.Input | ||||||
|  |               fluid | ||||||
|  |               icon='lock' | ||||||
|  |               iconPosition='left' | ||||||
|  |               placeholder='新密码' | ||||||
|  |               name='newPassword' | ||||||
|  |               value={newPassword} | ||||||
|  |               readOnly | ||||||
|  |               onClick={(e) => { | ||||||
|  |                 e.target.select(); | ||||||
|  |                 navigator.clipboard.writeText(newPassword); | ||||||
|  |                 showNotice(`密码已复制到剪贴板:${newPassword}`); | ||||||
|  |               }} | ||||||
|  |             />             | ||||||
|  |             )} | ||||||
|             <Button |             <Button | ||||||
|               color='' |               color='green' | ||||||
|               fluid |               fluid | ||||||
|               size='large' |               size='large' | ||||||
|               onClick={handleSubmit} |               onClick={handleSubmit} | ||||||
|               loading={loading} |               loading={loading} | ||||||
|  |               disabled={disableButton} | ||||||
|             > |             > | ||||||
|               提交 |               {disableButton ? `密码重置完成` : '提交'} | ||||||
|             </Button> |             </Button> | ||||||
|           </Segment> |           </Segment> | ||||||
|         </Form> |         </Form> | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ import Turnstile from 'react-turnstile'; | |||||||
|  |  | ||||||
| const PasswordResetForm = () => { | const PasswordResetForm = () => { | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
|     email: '', |     email: '' | ||||||
|   }); |   }); | ||||||
|   const { email } = inputs; |   const { email } = inputs; | ||||||
|  |  | ||||||
| @@ -13,24 +13,29 @@ const PasswordResetForm = () => { | |||||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); |   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); |   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||||
|   const [turnstileToken, setTurnstileToken] = useState(''); |   const [turnstileToken, setTurnstileToken] = useState(''); | ||||||
|  |   const [disableButton, setDisableButton] = useState(false); | ||||||
|  |   const [countdown, setCountdown] = useState(30); | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let status = localStorage.getItem('status'); |     let countdownInterval = null; | ||||||
|     if (status) { |     if (disableButton && countdown > 0) { | ||||||
|       status = JSON.parse(status); |       countdownInterval = setInterval(() => { | ||||||
|       if (status.turnstile_check) { |         setCountdown(countdown - 1); | ||||||
|         setTurnstileEnabled(true); |       }, 1000); | ||||||
|         setTurnstileSiteKey(status.turnstile_site_key); |     } else if (countdown === 0) { | ||||||
|  |       setDisableButton(false); | ||||||
|  |       setCountdown(30); | ||||||
|     } |     } | ||||||
|     } |     return () => clearInterval(countdownInterval); | ||||||
|   }, []); |   }, [disableButton, countdown]); | ||||||
|  |  | ||||||
|   function handleChange(e) { |   function handleChange(e) { | ||||||
|     const { name, value } = e.target; |     const { name, value } = e.target; | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs(inputs => ({ ...inputs, [name]: value })); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   async function handleSubmit(e) { |   async function handleSubmit(e) { | ||||||
|  |     setDisableButton(true); | ||||||
|     if (!email) return; |     if (!email) return; | ||||||
|     if (turnstileEnabled && turnstileToken === '') { |     if (turnstileEnabled && turnstileToken === '') { | ||||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); |       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||||
| @@ -78,13 +83,14 @@ const PasswordResetForm = () => { | |||||||
|               <></> |               <></> | ||||||
|             )} |             )} | ||||||
|             <Button |             <Button | ||||||
|               color='' |               color='green' | ||||||
|               fluid |               fluid | ||||||
|               size='large' |               size='large' | ||||||
|               onClick={handleSubmit} |               onClick={handleSubmit} | ||||||
|               loading={loading} |               loading={loading} | ||||||
|  |               disabled={disableButton} | ||||||
|             > |             > | ||||||
|               提交 |               {disableButton ? `重试 (${countdown})` : '提交'} | ||||||
|             </Button> |             </Button> | ||||||
|           </Segment> |           </Segment> | ||||||
|         </Form> |         </Form> | ||||||
|   | |||||||
| @@ -1,22 +1,30 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useContext, useEffect, useState } from 'react'; | ||||||
| import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; | import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link, useNavigate } from 'react-router-dom'; | ||||||
| import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | ||||||
| import Turnstile from 'react-turnstile'; | import Turnstile from 'react-turnstile'; | ||||||
|  | import { UserContext } from '../context/User'; | ||||||
|  |  | ||||||
| const PersonalSetting = () => { | const PersonalSetting = () => { | ||||||
|  |   const [userState, userDispatch] = useContext(UserContext); | ||||||
|  |   let navigate = useNavigate(); | ||||||
|  |  | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
|     wechat_verification_code: '', |     wechat_verification_code: '', | ||||||
|     email_verification_code: '', |     email_verification_code: '', | ||||||
|     email: '', |     email: '', | ||||||
|  |     self_account_deletion_confirmation: '' | ||||||
|   }); |   }); | ||||||
|   const [status, setStatus] = useState({}); |   const [status, setStatus] = useState({}); | ||||||
|   const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); |   const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); | ||||||
|   const [showEmailBindModal, setShowEmailBindModal] = useState(false); |   const [showEmailBindModal, setShowEmailBindModal] = useState(false); | ||||||
|  |   const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); | ||||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); |   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); |   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||||
|   const [turnstileToken, setTurnstileToken] = useState(''); |   const [turnstileToken, setTurnstileToken] = useState(''); | ||||||
|   const [loading, setLoading] = useState(false); |   const [loading, setLoading] = useState(false); | ||||||
|  |   const [disableButton, setDisableButton] = useState(false); | ||||||
|  |   const [countdown, setCountdown] = useState(30); | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let status = localStorage.getItem('status'); |     let status = localStorage.getItem('status'); | ||||||
| @@ -30,6 +38,19 @@ const PersonalSetting = () => { | |||||||
|     } |     } | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|  |   useEffect(() => { | ||||||
|  |     let countdownInterval = null; | ||||||
|  |     if (disableButton && countdown > 0) { | ||||||
|  |       countdownInterval = setInterval(() => { | ||||||
|  |         setCountdown(countdown - 1); | ||||||
|  |       }, 1000); | ||||||
|  |     } else if (countdown === 0) { | ||||||
|  |       setDisableButton(false); | ||||||
|  |       setCountdown(30); | ||||||
|  |     } | ||||||
|  |     return () => clearInterval(countdownInterval); // Clean up on unmount | ||||||
|  |   }, [disableButton, countdown]); | ||||||
|  |  | ||||||
|   const handleInputChange = (e, { name, value }) => { |   const handleInputChange = (e, { name, value }) => { | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|   }; |   }; | ||||||
| @@ -57,6 +78,26 @@ const PersonalSetting = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const deleteAccount = async () => { | ||||||
|  |     if (inputs.self_account_deletion_confirmation !== userState.user.username) { | ||||||
|  |       showError('请输入你的账户名以确认删除!'); | ||||||
|  |       return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const res = await API.delete('/api/user/self'); | ||||||
|  |     const { success, message } = res.data; | ||||||
|  |  | ||||||
|  |     if (success) { | ||||||
|  |       showSuccess('账户已删除!'); | ||||||
|  |       await API.get('/api/user/logout'); | ||||||
|  |       userDispatch({ type: 'logout' }); | ||||||
|  |       localStorage.removeItem('user'); | ||||||
|  |       navigate('/login'); | ||||||
|  |     } else { | ||||||
|  |       showError(message); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   const bindWeChat = async () => { |   const bindWeChat = async () => { | ||||||
|     if (inputs.wechat_verification_code === '') return; |     if (inputs.wechat_verification_code === '') return; | ||||||
|     const res = await API.get( |     const res = await API.get( | ||||||
| @@ -78,6 +119,7 @@ const PersonalSetting = () => { | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const sendVerificationCode = async () => { |   const sendVerificationCode = async () => { | ||||||
|  |     setDisableButton(true); | ||||||
|     if (inputs.email === '') return; |     if (inputs.email === '') return; | ||||||
|     if (turnstileEnabled && turnstileToken === '') { |     if (turnstileEnabled && turnstileToken === '') { | ||||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); |       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||||
| @@ -123,6 +165,9 @@ const PersonalSetting = () => { | |||||||
|       </Button> |       </Button> | ||||||
|       <Button onClick={generateAccessToken}>生成系统访问令牌</Button> |       <Button onClick={generateAccessToken}>生成系统访问令牌</Button> | ||||||
|       <Button onClick={getAffLink}>复制邀请链接</Button> |       <Button onClick={getAffLink}>复制邀请链接</Button> | ||||||
|  |       <Button onClick={() => { | ||||||
|  |         setShowAccountDeleteModal(true); | ||||||
|  |       }}>删除个人账户</Button> | ||||||
|       <Divider /> |       <Divider /> | ||||||
|       <Header as='h3'>账号绑定</Header> |       <Header as='h3'>账号绑定</Header> | ||||||
|       { |       { | ||||||
| @@ -195,8 +240,8 @@ const PersonalSetting = () => { | |||||||
|                 name='email' |                 name='email' | ||||||
|                 type='email' |                 type='email' | ||||||
|                 action={ |                 action={ | ||||||
|                   <Button onClick={sendVerificationCode} disabled={loading}> |                   <Button onClick={sendVerificationCode} disabled={disableButton || loading}> | ||||||
|                     获取验证码 |                     {disableButton ? `重新发送(${countdown})` : '获取验证码'} | ||||||
|                   </Button> |                   </Button> | ||||||
|                 } |                 } | ||||||
|               /> |               /> | ||||||
| @@ -230,6 +275,47 @@ const PersonalSetting = () => { | |||||||
|           </Modal.Description> |           </Modal.Description> | ||||||
|         </Modal.Content> |         </Modal.Content> | ||||||
|       </Modal> |       </Modal> | ||||||
|  |       <Modal | ||||||
|  |         onClose={() => setShowAccountDeleteModal(false)} | ||||||
|  |         onOpen={() => setShowAccountDeleteModal(true)} | ||||||
|  |         open={showAccountDeleteModal} | ||||||
|  |         size={'tiny'} | ||||||
|  |         style={{ maxWidth: '450px' }} | ||||||
|  |       > | ||||||
|  |         <Modal.Header>确认删除自己的帐户</Modal.Header> | ||||||
|  |         <Modal.Content> | ||||||
|  |           <Modal.Description> | ||||||
|  |             <Form size='large'> | ||||||
|  |               <Form.Input | ||||||
|  |                 fluid | ||||||
|  |                 placeholder={`输入你的账户名 ${userState?.user?.username} 以确认删除`} | ||||||
|  |                 name='self_account_deletion_confirmation' | ||||||
|  |                 value={inputs.self_account_deletion_confirmation} | ||||||
|  |                 onChange={handleInputChange} | ||||||
|  |               /> | ||||||
|  |               {turnstileEnabled ? ( | ||||||
|  |                 <Turnstile | ||||||
|  |                   sitekey={turnstileSiteKey} | ||||||
|  |                   onVerify={(token) => { | ||||||
|  |                     setTurnstileToken(token); | ||||||
|  |                   }} | ||||||
|  |                 /> | ||||||
|  |               ) : ( | ||||||
|  |                 <></> | ||||||
|  |               )} | ||||||
|  |               <Button | ||||||
|  |                 color='red' | ||||||
|  |                 fluid | ||||||
|  |                 size='large' | ||||||
|  |                 onClick={deleteAccount} | ||||||
|  |                 loading={loading} | ||||||
|  |               > | ||||||
|  |                 删除 | ||||||
|  |               </Button> | ||||||
|  |             </Form> | ||||||
|  |           </Modal.Description> | ||||||
|  |         </Modal.Content> | ||||||
|  |       </Modal> | ||||||
|     </div> |     </div> | ||||||
|   ); |   ); | ||||||
| }; | }; | ||||||
|   | |||||||
| @@ -1,13 +1,5 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { | import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; | ||||||
|   Button, |  | ||||||
|   Form, |  | ||||||
|   Grid, |  | ||||||
|   Header, |  | ||||||
|   Image, |  | ||||||
|   Message, |  | ||||||
|   Segment, |  | ||||||
| } from 'semantic-ui-react'; |  | ||||||
| import { Link, useNavigate } from 'react-router-dom'; | import { Link, useNavigate } from 'react-router-dom'; | ||||||
| import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; | import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; | ||||||
| import Turnstile from 'react-turnstile'; | import Turnstile from 'react-turnstile'; | ||||||
| @@ -18,7 +10,7 @@ const RegisterForm = () => { | |||||||
|     password: '', |     password: '', | ||||||
|     password2: '', |     password2: '', | ||||||
|     email: '', |     email: '', | ||||||
|     verification_code: '', |     verification_code: '' | ||||||
|   }); |   }); | ||||||
|   const { username, password, password2 } = inputs; |   const { username, password, password2 } = inputs; | ||||||
|   const [showEmailVerification, setShowEmailVerification] = useState(false); |   const [showEmailVerification, setShowEmailVerification] = useState(false); | ||||||
| @@ -178,7 +170,7 @@ const RegisterForm = () => { | |||||||
|               <></> |               <></> | ||||||
|             )} |             )} | ||||||
|             <Button |             <Button | ||||||
|               color='' |               color='green' | ||||||
|               fluid |               fluid | ||||||
|               size='large' |               size='large' | ||||||
|               onClick={handleSubmit} |               onClick={handleSubmit} | ||||||
|   | |||||||
| @@ -227,7 +227,7 @@ const UsersTable = () => { | |||||||
|                       content={user.email ? user.email : '未绑定邮箱地址'} |                       content={user.email ? user.email : '未绑定邮箱地址'} | ||||||
|                       key={user.username} |                       key={user.username} | ||||||
|                       header={user.display_name ? user.display_name : user.username} |                       header={user.display_name ? user.display_name : user.username} | ||||||
|                       trigger={<span>{renderText(user.username, 10)}</span>} |                       trigger={<span>{renderText(user.username, 15)}</span>} | ||||||
|                       hoverable |                       hoverable | ||||||
|                     /> |                     /> | ||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ export const CHANNEL_OPTIONS = [ | |||||||
|   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, |   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||||
|   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, |   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||||
|   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, |   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||||
|  |   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||||
|  |   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, |   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, |   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, |   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| export const toastConstants = { | export const toastConstants = { | ||||||
|   SUCCESS_TIMEOUT: 500, |   SUCCESS_TIMEOUT: 1500, | ||||||
|   INFO_TIMEOUT: 3000, |   INFO_TIMEOUT: 3000, | ||||||
|   ERROR_TIMEOUT: 5000, |   ERROR_TIMEOUT: 5000, | ||||||
|   WARNING_TIMEOUT: 10000, |   WARNING_TIMEOUT: 10000, | ||||||
|   | |||||||
| @@ -46,9 +46,7 @@ const About = () => { | |||||||
|             about.startsWith('https://') ? <iframe |             about.startsWith('https://') ? <iframe | ||||||
|               src={about} |               src={about} | ||||||
|               style={{ width: '100%', height: '100vh', border: 'none' }} |               style={{ width: '100%', height: '100vh', border: 'none' }} | ||||||
|             /> : <Segment> |             /> : <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> | ||||||
|               <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> |  | ||||||
|             </Segment> |  | ||||||
|           } |           } | ||||||
|         </> |         </> | ||||||
|       } |       } | ||||||
|   | |||||||
| @@ -35,6 +35,30 @@ const EditChannel = () => { | |||||||
|   const [customModel, setCustomModel] = useState(''); |   const [customModel, setCustomModel] = useState(''); | ||||||
|   const handleInputChange = (e, { name, value }) => { |   const handleInputChange = (e, { name, value }) => { | ||||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); |     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||||
|  |     if (name === 'type' && inputs.models.length === 0) { | ||||||
|  |       let localModels = []; | ||||||
|  |       switch (value) { | ||||||
|  |         case 14: | ||||||
|  |           localModels = ['claude-instant-1', 'claude-2']; | ||||||
|  |           break; | ||||||
|  |         case 11: | ||||||
|  |           localModels = ['PaLM-2']; | ||||||
|  |           break; | ||||||
|  |         case 15: | ||||||
|  |           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||||
|  |           break; | ||||||
|  |         case 17: | ||||||
|  |           localModels = ['qwen-v1', 'qwen-plus-v1']; | ||||||
|  |           break; | ||||||
|  |         case 16: | ||||||
|  |           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||||
|  |           break; | ||||||
|  |         case 18: | ||||||
|  |           localModels = ['SparkDesk']; | ||||||
|  |           break; | ||||||
|  |       } | ||||||
|  |       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||||
|  |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const loadChannel = async () => { |   const loadChannel = async () => { | ||||||
| @@ -132,7 +156,10 @@ const EditChannel = () => { | |||||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); |       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||||
|     } |     } | ||||||
|     if (localInputs.type === 3 && localInputs.other === '') { |     if (localInputs.type === 3 && localInputs.other === '') { | ||||||
|       localInputs.other = '2023-03-15-preview'; |       localInputs.other = '2023-06-01-preview'; | ||||||
|  |     } | ||||||
|  |     if (localInputs.model_mapping === '') { | ||||||
|  |       localInputs.model_mapping = '{}'; | ||||||
|     } |     } | ||||||
|     let res; |     let res; | ||||||
|     localInputs.models = localInputs.models.join(','); |     localInputs.models = localInputs.models.join(','); | ||||||
| @@ -192,7 +219,7 @@ const EditChannel = () => { | |||||||
|                   <Form.Input |                   <Form.Input | ||||||
|                     label='默认 API 版本' |                     label='默认 API 版本' | ||||||
|                     name='other' |                     name='other' | ||||||
|                     placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'} |                     placeholder={'请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||||
|                     onChange={handleInputChange} |                     onChange={handleInputChange} | ||||||
|                     value={inputs.other} |                     value={inputs.other} | ||||||
|                     autoComplete='new-password' |                     autoComplete='new-password' | ||||||
| @@ -215,26 +242,12 @@ const EditChannel = () => { | |||||||
|               </Form.Field> |               </Form.Field> | ||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|           { |  | ||||||
|             inputs.type !== 3 && inputs.type !== 8 && ( |  | ||||||
|               <Form.Field> |  | ||||||
|                 <Form.Input |  | ||||||
|                   label='镜像' |  | ||||||
|                   name='base_url' |  | ||||||
|                   placeholder={'此项可选,输入镜像站地址,格式为:https://domain.com'} |  | ||||||
|                   onChange={handleInputChange} |  | ||||||
|                   value={inputs.base_url} |  | ||||||
|                   autoComplete='new-password' |  | ||||||
|                 /> |  | ||||||
|               </Form.Field> |  | ||||||
|             ) |  | ||||||
|           } |  | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Input |             <Form.Input | ||||||
|               label='名称' |               label='名称' | ||||||
|               required |               required | ||||||
|               name='name' |               name='name' | ||||||
|               placeholder={'请输入名称'} |               placeholder={'请为渠道命名'} | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={inputs.name} |               value={inputs.name} | ||||||
|               autoComplete='new-password' |               autoComplete='new-password' | ||||||
| @@ -243,7 +256,7 @@ const EditChannel = () => { | |||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Dropdown |             <Form.Dropdown | ||||||
|               label='分组' |               label='分组' | ||||||
|               placeholder={'请选择分组'} |               placeholder={'请选择可以使用该渠道的分组'} | ||||||
|               name='groups' |               name='groups' | ||||||
|               required |               required | ||||||
|               fluid |               fluid | ||||||
| @@ -260,7 +273,7 @@ const EditChannel = () => { | |||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Dropdown |             <Form.Dropdown | ||||||
|               label='模型' |               label='模型' | ||||||
|               placeholder={'请选择该通道所支持的模型'} |               placeholder={'请选择该渠道所支持的模型'} | ||||||
|               name='models' |               name='models' | ||||||
|               required |               required | ||||||
|               fluid |               fluid | ||||||
| @@ -285,7 +298,7 @@ const EditChannel = () => { | |||||||
|             <Input |             <Input | ||||||
|               action={ |               action={ | ||||||
|                 <Button type={'button'} onClick={() => { |                 <Button type={'button'} onClick={() => { | ||||||
|                   if (customModel.trim() === "") return; |                   if (customModel.trim() === '') return; | ||||||
|                   if (inputs.models.includes(customModel)) return; |                   if (inputs.models.includes(customModel)) return; | ||||||
|                   let localModels = [...inputs.models]; |                   let localModels = [...inputs.models]; | ||||||
|                   localModels.push(customModel); |                   localModels.push(customModel); | ||||||
| @@ -293,7 +306,7 @@ const EditChannel = () => { | |||||||
|                   localModelOptions.push({ |                   localModelOptions.push({ | ||||||
|                     key: customModel, |                     key: customModel, | ||||||
|                     text: customModel, |                     text: customModel, | ||||||
|                     value: customModel, |                     value: customModel | ||||||
|                   }); |                   }); | ||||||
|                   setModelOptions(modelOptions => { |                   setModelOptions(modelOptions => { | ||||||
|                     return [...modelOptions, ...localModelOptions]; |                     return [...modelOptions, ...localModelOptions]; | ||||||
| @@ -311,8 +324,8 @@ const EditChannel = () => { | |||||||
|           </div> |           </div> | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.TextArea |             <Form.TextArea | ||||||
|               label='模型映射' |               label='模型重定向' | ||||||
|               placeholder={`此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} |               placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} | ||||||
|               name='model_mapping' |               name='model_mapping' | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={inputs.model_mapping} |               value={inputs.model_mapping} | ||||||
| @@ -337,7 +350,7 @@ const EditChannel = () => { | |||||||
|                 label='密钥' |                 label='密钥' | ||||||
|                 name='key' |                 name='key' | ||||||
|                 required |                 required | ||||||
|                 placeholder={inputs.type === 15 ? "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次" : '请输入密钥'} |                 placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} | ||||||
|                 onChange={handleInputChange} |                 onChange={handleInputChange} | ||||||
|                 value={inputs.key} |                 value={inputs.key} | ||||||
|                 autoComplete='new-password' |                 autoComplete='new-password' | ||||||
| @@ -354,7 +367,21 @@ const EditChannel = () => { | |||||||
|               /> |               /> | ||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|           <Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button> |           { | ||||||
|  |             inputs.type !== 3 && inputs.type !== 8 && ( | ||||||
|  |               <Form.Field> | ||||||
|  |                 <Form.Input | ||||||
|  |                   label='镜像' | ||||||
|  |                   name='base_url' | ||||||
|  |                   placeholder={'此项可选,用于通过镜像站来进行 API 调用,请输入镜像站地址,格式为:https://domain.com'} | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.base_url} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |             ) | ||||||
|  |           } | ||||||
|  |           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button> | ||||||
|         </Form> |         </Form> | ||||||
|       </Segment> |       </Segment> | ||||||
|     </> |     </> | ||||||
|   | |||||||
| @@ -83,7 +83,7 @@ const EditToken = () => { | |||||||
|       if (isEdit) { |       if (isEdit) { | ||||||
|         showSuccess('令牌更新成功!'); |         showSuccess('令牌更新成功!'); | ||||||
|       } else { |       } else { | ||||||
|         showSuccess('令牌创建成功!'); |         showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!'); | ||||||
|         setInputs(originInputs); |         setInputs(originInputs); | ||||||
|       } |       } | ||||||
|     } else { |     } else { | ||||||
|   | |||||||
| @@ -7,12 +7,15 @@ const TopUp = () => { | |||||||
|   const [redemptionCode, setRedemptionCode] = useState(''); |   const [redemptionCode, setRedemptionCode] = useState(''); | ||||||
|   const [topUpLink, setTopUpLink] = useState(''); |   const [topUpLink, setTopUpLink] = useState(''); | ||||||
|   const [userQuota, setUserQuota] = useState(0); |   const [userQuota, setUserQuota] = useState(0); | ||||||
|  |   const [isSubmitting, setIsSubmitting] = useState(false); | ||||||
|  |  | ||||||
|   const topUp = async () => { |   const topUp = async () => { | ||||||
|     if (redemptionCode === '') { |     if (redemptionCode === '') { | ||||||
|       showInfo('请输入充值码!') |       showInfo('请输入充值码!') | ||||||
|       return; |       return; | ||||||
|     } |     } | ||||||
|  |     setIsSubmitting(true); | ||||||
|  |     try { | ||||||
|       const res = await API.post('/api/user/topup', { |       const res = await API.post('/api/user/topup', { | ||||||
|         key: redemptionCode |         key: redemptionCode | ||||||
|       }); |       }); | ||||||
| @@ -26,6 +29,11 @@ const TopUp = () => { | |||||||
|       } else { |       } else { | ||||||
|         showError(message); |         showError(message); | ||||||
|       } |       } | ||||||
|  |     } catch (err) { | ||||||
|  |       showError('请求失败'); | ||||||
|  |     } finally { | ||||||
|  |       setIsSubmitting(false);  | ||||||
|  |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const openTopUpLink = () => { |   const openTopUpLink = () => { | ||||||
| @@ -74,8 +82,8 @@ const TopUp = () => { | |||||||
|             <Button color='green' onClick={openTopUpLink}> |             <Button color='green' onClick={openTopUpLink}> | ||||||
|               获取兑换码 |               获取兑换码 | ||||||
|             </Button> |             </Button> | ||||||
|             <Button color='yellow' onClick={topUp}> |             <Button color='yellow' onClick={topUp} disabled={isSubmitting}> | ||||||
|               充值 |                 {isSubmitting ? '兑换中...' : '兑换'} | ||||||
|             </Button> |             </Button> | ||||||
|           </Form> |           </Form> | ||||||
|         </Grid.Column> |         </Grid.Column> | ||||||
| @@ -92,5 +100,4 @@ const TopUp = () => { | |||||||
|   ); |   ); | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  |  | ||||||
| export default TopUp; | export default TopUp; | ||||||
		Reference in New Issue
	
	Block a user