mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 10:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			17 Commits
		
	
	
		
			v0.5.2-alp
			...
			0.4.8
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 2ae5741214 | ||
|  | 28d58849a0 | ||
|  | adc9679d56 | ||
|  | 07589ae305 | ||
|  | 95bc32c555 | ||
|  | d61dc4a9ca | ||
|  | b29acb0c89 | ||
|  | a8e418275d | ||
|  | d7ab9b0935 | ||
|  | ebd62c3bfc | ||
|  | 4e6f9f67b3 | ||
|  | 77d295bbf5 | ||
|  | 7e3e25fbd9 | ||
|  | 9bf98ab53a | ||
|  | 3ed9a219c7 | ||
|  | 37d7afcedc | ||
|  | 2756554f7c | 
							
								
								
									
										2
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -38,7 +38,7 @@ jobs: | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api-en | ||||
|             ckt1031/one-api-en | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|   | ||||
							
								
								
									
										2
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -42,7 +42,7 @@ jobs: | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api | ||||
|             ckt1031/one-api | ||||
|             ghcr.io/${{ github.repository }} | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|   | ||||
							
								
								
									
										2
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -49,7 +49,7 @@ jobs: | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             justsong/one-api | ||||
|             ckt1031/one-api | ||||
|             ghcr.io/${{ github.repository }} | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|   | ||||
							
								
								
									
										24
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								Dockerfile
									
									
									
									
									
								
							| @@ -1,31 +1,29 @@ | ||||
| FROM node:16 as builder | ||||
|  | ||||
| # Node build stage | ||||
| FROM node:18 as builder | ||||
| WORKDIR /build | ||||
| COPY ./web/package*.json ./ | ||||
| RUN npm ci | ||||
| COPY ./web . | ||||
| COPY ./VERSION . | ||||
| RUN npm install | ||||
| RUN REACT_APP_VERSION=$(cat VERSION) npm run build | ||||
|  | ||||
| # Go build stage | ||||
| FROM golang AS builder2 | ||||
|  | ||||
| ENV GO111MODULE=on \ | ||||
|     CGO_ENABLED=1 \ | ||||
|     GOOS=linux | ||||
|  | ||||
| WORKDIR /build | ||||
| COPY go.mod . | ||||
| COPY go.sum . | ||||
| RUN go mod download | ||||
| COPY . . | ||||
| COPY --from=builder /build/build ./web/build | ||||
| RUN go mod download | ||||
| RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api | ||||
|  | ||||
| # Final stage | ||||
| FROM alpine | ||||
|  | ||||
| RUN apk update \ | ||||
|     && apk upgrade \ | ||||
|     && apk add --no-cache ca-certificates tzdata \ | ||||
|     && update-ca-certificates 2>/dev/null || true | ||||
|  | ||||
| RUN apk update && apk upgrade && apk add --no-cache ca-certificates tzdata && update-ca-certificates 2>/dev/null || true | ||||
| WORKDIR /data | ||||
| COPY --from=builder2 /build/one-api / | ||||
| EXPOSE 3000 | ||||
| WORKDIR /data | ||||
| ENTRYPOINT ["/one-api"] | ||||
|   | ||||
							
								
								
									
										18
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -10,7 +10,7 @@ | ||||
|  | ||||
| # One API | ||||
|  | ||||
| _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_ | ||||
| _✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_ | ||||
|  | ||||
| </div> | ||||
|  | ||||
| @@ -57,13 +57,15 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use | ||||
| > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | ||||
|  | ||||
| ## Features | ||||
| 1. Support for multiple large models: | ||||
|    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) | ||||
|    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||
| 1. Supports multiple API access channels: | ||||
|     + [x] Official OpenAI channel (support proxy configuration) | ||||
|     + [x] **Azure OpenAI API** | ||||
|     + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) | ||||
|     + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|     + [x] [API2D](https://api2d.com/r/197971) | ||||
|     + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
|     + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`) | ||||
|     + [x] Custom channel: Various third-party proxy services not included in the list | ||||
| 2. Supports access to multiple channels through **load balancing**. | ||||
| 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | ||||
| 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | ||||
|   | ||||
							
								
								
									
										52
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								README.md
									
									
									
									
									
								
							| @@ -11,7 +11,7 @@ | ||||
|  | ||||
| # One API | ||||
|  | ||||
| _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ | ||||
| _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用✨_ | ||||
|  | ||||
| </div> | ||||
|  | ||||
| @@ -58,45 +58,39 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| > **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。 | ||||
|  | ||||
| ## 功能 | ||||
| 1. 支持多种大模型: | ||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) | ||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||
| 2. 支持配置镜像以及众多第三方代理服务: | ||||
| 1. 支持多种 API 访问渠道: | ||||
|    + [x] OpenAI 官方通道(支持配置镜像) | ||||
|    + [x] **Azure OpenAI API** | ||||
|    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) | ||||
|    + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|    + [x] [API2D](https://api2d.com/r/197971) | ||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||
|    + [x] 自定义渠道:例如各种未收录的第三方代理服务 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||
| 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||
| 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||
| 8. 支持**通道管理**,批量创建通道。 | ||||
| 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||
| 10. 支持渠道**设置模型列表**。 | ||||
| 11. 支持**查看额度明细**。 | ||||
| 12. 支持**用户邀请奖励**。 | ||||
| 13. 支持以美元为单位显示额度。 | ||||
| 14. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||
| 15. 支持模型映射,重定向用户的请求模型。 | ||||
| 16. 支持失败自动重试。 | ||||
| 17. 支持绘图接口。 | ||||
| 18. 支持丰富的**自定义**设置, | ||||
| 2. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| 4. 支持**多机部署**,[详见此处](#多机部署)。 | ||||
| 5. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||
| 6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||
| 7. 支持**通道管理**,批量创建通道。 | ||||
| 8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||
| 9. 支持渠道**设置模型列表**。 | ||||
| 10. 支持**查看额度明细**。 | ||||
| 11. 支持**用户邀请奖励**。 | ||||
| 12. 支持以美元为单位显示额度。 | ||||
| 13. 支持发布公告,设置充值链接,设置新用户初始额度。 | ||||
| 14. 支持模型映射,重定向用户的请求模型。 | ||||
| 15. 支持丰富的**自定义**设置, | ||||
|     1. 支持自定义系统名称,logo 以及页脚。 | ||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||
| 19. 支持通过系统访问令牌访问管理 API。 | ||||
| 20. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 21. 支持用户管理,支持**多种用户登录注册方式**: | ||||
| 16. 支持通过系统访问令牌访问管理 API。 | ||||
| 17. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 18. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册以及通过邮箱进行密码重置。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 19. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。 | ||||
|  | ||||
| ## 部署 | ||||
| ### 基于 Docker 进行部署 | ||||
|   | ||||
| @@ -68,7 +68,6 @@ var AutomaticDisableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| var RetryTimes = 0 | ||||
|  | ||||
| var RootUserEmail = "" | ||||
|  | ||||
| @@ -77,8 +76,6 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | ||||
|  | ||||
| var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY | ||||
|  | ||||
| const ( | ||||
| 	RoleGuestUser  = 0 | ||||
| 	RoleCommonUser = 1 | ||||
| @@ -153,31 +150,21 @@ const ( | ||||
| 	ChannelTypePaLM      = 11 | ||||
| 	ChannelTypeAPI2GPT   = 12 | ||||
| 	ChannelTypeAIGC2D    = 13 | ||||
| 	ChannelTypeAnthropic = 14 | ||||
| 	ChannelTypeBaidu     = 15 | ||||
| 	ChannelTypeZhipu     = 16 | ||||
| 	ChannelTypeAli       = 17 | ||||
| 	ChannelTypeXunfei    = 18 | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| 	"",                               // 0 | ||||
| 	"https://api.openai.com",         // 1 | ||||
| 	"https://oa.api2d.net",           // 2 | ||||
| 	"",                               // 3 | ||||
| 	"https://api.closeai-proxy.xyz",  // 4 | ||||
| 	"https://api.openai-sb.com",      // 5 | ||||
| 	"https://api.openaimax.com",      // 6 | ||||
| 	"https://api.ohmygpt.com",        // 7 | ||||
| 	"",                               // 8 | ||||
| 	"https://api.caipacity.com",      // 9 | ||||
| 	"https://api.aiproxy.io",         // 10 | ||||
| 	"",                               // 11 | ||||
| 	"https://api.api2gpt.com",        // 12 | ||||
| 	"https://api.aigc2d.com",         // 13 | ||||
| 	"https://api.anthropic.com",      // 14 | ||||
| 	"https://aip.baidubce.com",       // 15 | ||||
| 	"https://open.bigmodel.cn",       // 16 | ||||
| 	"https://dashscope.aliyuncs.com", // 17 | ||||
| 	"",                               // 18 | ||||
| 	"",                              // 0 | ||||
| 	"https://api.openai.com",        // 1 | ||||
| 	"https://oa.api2d.net",          // 2 | ||||
| 	"",                              // 3 | ||||
| 	"https://api.closeai-proxy.xyz", // 4 | ||||
| 	"https://api.openai-sb.com",     // 5 | ||||
| 	"https://api.openaimax.com",     // 6 | ||||
| 	"https://api.ohmygpt.com",       // 7 | ||||
| 	"",                              // 8 | ||||
| 	"https://api.caipacity.com",     // 9 | ||||
| 	"https://api.aiproxy.io",        // 10 | ||||
| 	"",                              // 11 | ||||
| 	"https://api.api2gpt.com",       // 12 | ||||
| 	"https://api.aigc2d.com",        // 13 | ||||
| } | ||||
|   | ||||
| @@ -4,11 +4,9 @@ import "encoding/json" | ||||
|  | ||||
| // ModelRatio | ||||
| // https://platform.openai.com/docs/models/model-endpoint-compatibility | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf | ||||
| // https://openai.com/pricing | ||||
| // TODO: when a new api is enabled, check the pricing here | ||||
| // 1 === $0.002 / 1K tokens | ||||
| // 1 === ¥0.014 / 1k tokens | ||||
| var ModelRatio = map[string]float64{ | ||||
| 	"gpt-4":                   15, | ||||
| 	"gpt-4-0314":              15, | ||||
| @@ -38,18 +36,6 @@ var ModelRatio = map[string]float64{ | ||||
| 	"text-moderation-stable":  0.1, | ||||
| 	"text-moderation-latest":  0.1, | ||||
| 	"dall-e":                  8, | ||||
| 	"claude-instant-1":        0.75, | ||||
| 	"claude-2":                30, | ||||
| 	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens | ||||
| 	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens | ||||
| 	"PaLM-2":                  1, | ||||
| 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens | ||||
| 	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag | ||||
| 	"qwen-plus-v1":            0.5715, // Same as above | ||||
| 	"SparkDesk":               0.8572, // TBD | ||||
| } | ||||
|  | ||||
| func ModelRatio2JSONString() string { | ||||
|   | ||||
| @@ -7,24 +7,16 @@ import ( | ||||
| ) | ||||
|  | ||||
| func GetSubscription(c *gin.Context) { | ||||
| 	var remainQuota int | ||||
| 	var usedQuota int | ||||
| 	var quota int | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	var expiredTime int64 | ||||
| 	if common.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		expiredTime = token.ExpiredTime | ||||
| 		remainQuota = token.RemainQuota | ||||
| 		usedQuota = token.UsedQuota | ||||
| 		quota = token.RemainQuota | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		remainQuota, err = model.GetUserQuota(userId) | ||||
| 		usedQuota, err = model.GetUserUsedQuota(userId) | ||||
| 	} | ||||
| 	if expiredTime <= 0 { | ||||
| 		expiredTime = 0 | ||||
| 		quota, err = model.GetUserQuota(userId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		openAIError := OpenAIError{ | ||||
| @@ -36,7 +28,6 @@ func GetSubscription(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	quota := remainQuota + usedQuota | ||||
| 	amount := float64(quota) | ||||
| 	if common.DisplayInCurrencyEnabled { | ||||
| 		amount /= common.QuotaPerUnit | ||||
| @@ -50,7 +41,6 @@ func GetSubscription(c *gin.Context) { | ||||
| 		SoftLimitUSD:       amount, | ||||
| 		HardLimitUSD:       amount, | ||||
| 		SystemHardLimitUSD: amount, | ||||
| 		AccessUntil:        expiredTime, | ||||
| 	} | ||||
| 	c.JSON(200, subscription) | ||||
| 	return | ||||
|   | ||||
| @@ -22,7 +22,6 @@ type OpenAISubscriptionResponse struct { | ||||
| 	SoftLimitUSD       float64 `json:"soft_limit_usd"` | ||||
| 	HardLimitUSD       float64 `json:"hard_limit_usd"` | ||||
| 	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` | ||||
| 	AccessUntil        int64   `json:"access_until"` | ||||
| } | ||||
|  | ||||
| type OpenAIUsageDailyCost struct { | ||||
| @@ -85,6 +84,7 @@ func GetAuthHeader(token string) http.Header { | ||||
| } | ||||
|  | ||||
| func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { | ||||
| 	client := &http.Client{} | ||||
| 	req, err := http.NewRequest(method, url, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -92,13 +92,10 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | ||||
| 	for k := range headers { | ||||
| 		req.Header.Add(k, headers.Get(k)) | ||||
| 	} | ||||
| 	res, err := httpClient.Do(req) | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if res.StatusCode != http.StatusOK { | ||||
| 		return nil, fmt.Errorf("status code: %d", res.StatusCode) | ||||
| 	} | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|   | ||||
| @@ -1,29 +1,25 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { | ||||
| func testChannel(channel *model.Channel, request ChatRequest) error { | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		request.Model = "gpt-35-turbo" | ||||
| 	default: | ||||
| @@ -41,11 +37,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr | ||||
|  | ||||
| 	jsonData, err := json.Marshal(request) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return err | ||||
| 	} | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		req.Header.Set("api-key", channel.Key) | ||||
| @@ -53,20 +49,75 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr | ||||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 		return err | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	var response TextResponse | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&response) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	var streamResponseText string | ||||
|  | ||||
| 	if isStream { | ||||
| 		scanner := bufio.NewScanner(resp.Body) | ||||
| 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 			if atEOF && len(data) == 0 { | ||||
| 				return 0, nil, nil | ||||
| 			} | ||||
|  | ||||
| 			if i := strings.Index(string(data), "\n\n"); i >= 0 { | ||||
| 				return i + 2, data[0:i], nil | ||||
| 			} | ||||
|  | ||||
| 			if atEOF { | ||||
| 				return len(data), data, nil | ||||
| 			} | ||||
|  | ||||
| 			return 0, nil, nil | ||||
| 		}) | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // must be something wrong! | ||||
| 				common.SysError("invalid stream response: " + data) | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[6:] | ||||
| 			if !strings.HasPrefix(data, "[DONE]") { | ||||
| 				var streamResponse ChatCompletionsStreamResponse | ||||
| 				err = json.Unmarshal([]byte(data), &streamResponse) | ||||
| 				if err != nil { | ||||
| 					common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 					return err | ||||
| 				} | ||||
| 				for _, choice := range streamResponse.Choices { | ||||
| 					streamResponseText += choice.Delta.Content | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if streamResponseText == "" { | ||||
| 			return errors.New("empty stream response") | ||||
| 		} | ||||
| 	} else { | ||||
| 		body, err := ioutil.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		err = json.Unmarshal(body, &response) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// channel.BaseURL starts with https://api.openai.com | ||||
| 		if response.Usage.CompletionTokens == 0 && strings.HasPrefix(channel.BaseURL, "https://api.openai.com") { | ||||
| 			return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) | ||||
| 		} | ||||
| 	} | ||||
| 	if response.Usage.CompletionTokens == 0 { | ||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||
| 	} | ||||
| 	return nil, nil | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func buildTestRequest() *ChatRequest { | ||||
| @@ -101,7 +152,7 @@ func TestChannel(c *gin.Context) { | ||||
| 	} | ||||
| 	testRequest := buildTestRequest() | ||||
| 	tik := time.Now() | ||||
| 	err, _ = testChannel(channel, *testRequest) | ||||
| 	err = testChannel(channel, *testRequest) | ||||
| 	tok := time.Now() | ||||
| 	milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 	go channel.UpdateResponseTime(milliseconds) | ||||
| @@ -165,14 +216,13 @@ func testAllChannels(notify bool) error { | ||||
| 				continue | ||||
| 			} | ||||
| 			tik := time.Now() | ||||
| 			err, openaiErr := testChannel(channel, *testRequest) | ||||
| 			err := testChannel(channel, *testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if shouldDisableChannel(openaiErr) { | ||||
| 			if err != nil || milliseconds > disableThreshold { | ||||
| 				if milliseconds > disableThreshold { | ||||
| 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				} | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			channel.UpdateResponseTime(milliseconds) | ||||
|   | ||||
| @@ -127,9 +127,8 @@ func SendPasswordResetEmail(c *gin.Context) { | ||||
| 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) | ||||
| 	subject := fmt.Sprintf("%s密码重置", common.SystemName) | ||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ | ||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) | ||||
| 		"<p>点击<a href='%s'>此处</a>进行密码重置。</p>"+ | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
| @@ -252,114 +252,6 @@ func init() { | ||||
| 			Root:       "code-davinci-edit-001", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-instant-1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-instant-1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "claude-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "anturopic", | ||||
| 			Permission: permission, | ||||
| 			Root:       "claude-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "ERNIE-Bot", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "ERNIE-Bot-turbo", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "ERNIE-Bot-turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "Embedding-V1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "Embedding-V1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "PaLM-2", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "google", | ||||
| 			Permission: permission, | ||||
| 			Root:       "PaLM-2", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_pro", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_pro", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_std", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_std", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "chatglm_lite", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "zhipu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "chatglm_lite", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-v1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-v1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-plus-v1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-plus-v1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "SparkDesk", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "xunfei", | ||||
| 			Permission: permission, | ||||
| 			Root:       "SparkDesk", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 	} | ||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | ||||
| 	for _, model := range openAIModels { | ||||
|   | ||||
| @@ -1,240 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||
|  | ||||
| type AliMessage struct { | ||||
| 	User string `json:"user"` | ||||
| 	Bot  string `json:"bot"` | ||||
| } | ||||
|  | ||||
| type AliInput struct { | ||||
| 	Prompt  string       `json:"prompt"` | ||||
| 	History []AliMessage `json:"history"` | ||||
| } | ||||
|  | ||||
| type AliParameters struct { | ||||
| 	TopP         float64 `json:"top_p,omitempty"` | ||||
| 	TopK         int     `json:"top_k,omitempty"` | ||||
| 	Seed         uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliChatRequest struct { | ||||
| 	Model      string        `json:"model"` | ||||
| 	Input      AliInput      `json:"input"` | ||||
| 	Parameters AliParameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliError struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type AliUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
|  | ||||
| type AliOutput struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type AliChatResponse struct { | ||||
| 	Output AliOutput `json:"output"` | ||||
| 	Usage  AliUsage  `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||
| 	prompt := "" | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			if i == len(request.Messages)-1 { | ||||
| 				prompt = message.Content | ||||
| 				break | ||||
| 			} | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  request.Messages[i+1].Content, | ||||
| 			}) | ||||
| 			i++ | ||||
| 		} | ||||
| 	} | ||||
| 	return &AliChatRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Input: AliInput{ | ||||
| 			Prompt:  prompt, | ||||
| 			History: messages, | ||||
| 		}, | ||||
| 		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's | ||||
| 		//	TopP: request.TopP, | ||||
| 		//	TopK: 50, | ||||
| 		//	//Seed:         0, | ||||
| 		//	//EnableSearch: false, | ||||
| 		//}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage: Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	choice.FinishReason = aliResponse.Output.FinishReason | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	lastResponseText := "" | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse AliChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage.PromptTokens += aliResponse.Usage.InputTokens | ||||
| 			usage.CompletionTokens += aliResponse.Usage.OutputTokens | ||||
| 			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var aliResponse AliChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -1,299 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 | ||||
|  | ||||
| type BaiduTokenResponse struct { | ||||
| 	RefreshToken  string `json:"refresh_token"` | ||||
| 	ExpiresIn     int    `json:"expires_in"` | ||||
| 	SessionKey    string `json:"session_key"` | ||||
| 	AccessToken   string `json:"access_token"` | ||||
| 	Scope         string `json:"scope"` | ||||
| 	SessionSecret string `json:"session_secret"` | ||||
| } | ||||
|  | ||||
| type BaiduMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type BaiduChatRequest struct { | ||||
| 	Messages []BaiduMessage `json:"messages"` | ||||
| 	Stream   bool           `json:"stream"` | ||||
| 	UserId   string         `json:"user_id,omitempty"` | ||||
| } | ||||
|  | ||||
| type BaiduError struct { | ||||
| 	ErrorCode int    `json:"error_code"` | ||||
| 	ErrorMsg  string `json:"error_msg"` | ||||
| } | ||||
|  | ||||
| type BaiduChatResponse struct { | ||||
| 	Id               string `json:"id"` | ||||
| 	Object           string `json:"object"` | ||||
| 	Created          int64  `json:"created"` | ||||
| 	Result           string `json:"result"` | ||||
| 	IsTruncated      bool   `json:"is_truncated"` | ||||
| 	NeedClearHistory bool   `json:"need_clear_history"` | ||||
| 	Usage            Usage  `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| type BaiduChatStreamResponse struct { | ||||
| 	BaiduChatResponse | ||||
| 	SentenceId int  `json:"sentence_id"` | ||||
| 	IsEnd      bool `json:"is_end"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingRequest struct { | ||||
| 	Input []string `json:"input"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingData struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	Index     int       `json:"index"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingResponse struct { | ||||
| 	Id      string               `json:"id"` | ||||
| 	Object  string               `json:"object"` | ||||
| 	Created int64                `json:"created"` | ||||
| 	Data    []BaiduEmbeddingData `json:"data"` | ||||
| 	Usage   Usage                `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &BaiduChatRequest{ | ||||
| 		Messages: messages, | ||||
| 		Stream:   request.Stream, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Result, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.Id, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: response.Created, | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage:   response.Usage, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = baiduResponse.Result | ||||
| 	choice.FinishReason = "stop" | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      baiduResponse.Id, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: baiduResponse.Created, | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||
| 	baiduEmbeddingRequest := BaiduEmbeddingRequest{ | ||||
| 		Input: nil, | ||||
| 	} | ||||
| 	switch request.Input.(type) { | ||||
| 	case string: | ||||
| 		baiduEmbeddingRequest.Input = []string{request.Input.(string)} | ||||
| 	case []string: | ||||
| 		baiduEmbeddingRequest.Input = request.Input.([]string) | ||||
| 	} | ||||
| 	return &baiduEmbeddingRequest | ||||
| } | ||||
|  | ||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Model:  "baidu-embedding", | ||||
| 		Usage:  response.Usage, | ||||
| 	} | ||||
| 	for _, item := range response.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[6:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var baiduResponse BaiduChatStreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &baiduResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage.PromptTokens += baiduResponse.Usage.PromptTokens | ||||
| 			usage.CompletionTokens += baiduResponse.Usage.CompletionTokens | ||||
| 			usage.TotalTokens += baiduResponse.Usage.TotalTokens | ||||
| 			response := streamResponseBaidu2OpenAI(&baiduResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduEmbeddingResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -1,221 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type ClaudeMetadata struct { | ||||
| 	UserId string `json:"user_id"` | ||||
| } | ||||
|  | ||||
| type ClaudeRequest struct { | ||||
| 	Model             string   `json:"model"` | ||||
| 	Prompt            string   `json:"prompt"` | ||||
| 	MaxTokensToSample int      `json:"max_tokens_to_sample"` | ||||
| 	StopSequences     []string `json:"stop_sequences,omitempty"` | ||||
| 	Temperature       float64  `json:"temperature,omitempty"` | ||||
| 	TopP              float64  `json:"top_p,omitempty"` | ||||
| 	TopK              int      `json:"top_k,omitempty"` | ||||
| 	//ClaudeMetadata    `json:"metadata,omitempty"` | ||||
| 	Stream bool `json:"stream,omitempty"` | ||||
| } | ||||
|  | ||||
| type ClaudeError struct { | ||||
| 	Type    string `json:"type"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type ClaudeResponse struct { | ||||
| 	Completion string      `json:"completion"` | ||||
| 	StopReason string      `json:"stop_reason"` | ||||
| 	Model      string      `json:"model"` | ||||
| 	Error      ClaudeError `json:"error"` | ||||
| } | ||||
|  | ||||
| func stopReasonClaude2OpenAI(reason string) string { | ||||
| 	switch reason { | ||||
| 	case "stop_sequence": | ||||
| 		return "stop" | ||||
| 	case "max_tokens": | ||||
| 		return "length" | ||||
| 	default: | ||||
| 		return reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 	claudeRequest := ClaudeRequest{ | ||||
| 		Model:             textRequest.Model, | ||||
| 		Prompt:            "", | ||||
| 		MaxTokensToSample: textRequest.MaxTokens, | ||||
| 		StopSequences:     nil, | ||||
| 		Temperature:       textRequest.Temperature, | ||||
| 		TopP:              textRequest.TopP, | ||||
| 		Stream:            textRequest.Stream, | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokensToSample == 0 { | ||||
| 		claudeRequest.MaxTokensToSample = 1000000 | ||||
| 	} | ||||
| 	prompt := "" | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		if message.Role == "user" { | ||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | ||||
| 		} else if message.Role == "assistant" { | ||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||
| 		} else if message.Role == "system" { | ||||
| 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||
| 		} | ||||
| 	} | ||||
| 	prompt += "\n\nAssistant:" | ||||
| 	claudeRequest.Prompt = prompt | ||||
| 	return &claudeRequest | ||||
| } | ||||
|  | ||||
| func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = claudeResponse.Completion | ||||
| 	choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason) | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = claudeResponse.Model | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: strings.TrimPrefix(claudeResponse.Completion, " "), | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { | ||||
| 			return i + 4, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if !strings.HasPrefix(data, "event: completion") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "event: completion\r\ndata: ") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var claudeResponse ClaudeResponse | ||||
| 			err := json.Unmarshal([]byte(data), &claudeResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			responseText += claudeResponse.Completion | ||||
| 			response := streamResponseClaude2OpenAI(&claudeResponse) | ||||
| 			response.Id = responseId | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var claudeResponse ClaudeResponse | ||||
| 	err = json.Unmarshal(responseBody, &claudeResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if claudeResponse.Error.Type != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: claudeResponse.Error.Message, | ||||
| 				Type:    claudeResponse.Error.Type, | ||||
| 				Param:   "", | ||||
| 				Code:    claudeResponse.Error.Type, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseClaude2OpenAI(&claudeResponse) | ||||
| 	completionTokens := countTokenText(claudeResponse.Completion, model) | ||||
| 	usage := Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
| @@ -22,26 +22,26 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	consumeQuota := c.GetBool("consume_quota") | ||||
| 	group := c.GetString("group") | ||||
|  | ||||
| 	var imageRequest ImageRequest | ||||
| 	var textRequest GeneralOpenAIRequest | ||||
| 	if consumeQuota { | ||||
| 		err := common.UnmarshalBodyReusable(c, &imageRequest) | ||||
| 		err := common.UnmarshalBodyReusable(c, &textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Prompt validation | ||||
| 	if imageRequest.Prompt == "" { | ||||
| 	if textRequest.Prompt == "" { | ||||
| 		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// Not "256x256", "512x512", or "1024x1024" | ||||
| 	if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { | ||||
| 	if textRequest.Size != "" && textRequest.Size != "256x256" && textRequest.Size != "512x512" && textRequest.Size != "1024x1024" { | ||||
| 		return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| 	// N should between 1 and 10 | ||||
| 	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { | ||||
| 	// N should between 1 to 10 | ||||
| 	if textRequest.N != 0 && (textRequest.N < 1 || textRequest.N > 10) { | ||||
| 		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) | ||||
| 	} | ||||
|  | ||||
| @@ -71,7 +71,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
|  | ||||
| 	var requestBody io.Reader | ||||
| 	if isModelMapped { | ||||
| 		jsonStr, err := json.Marshal(imageRequest) | ||||
| 		jsonStr, err := json.Marshal(textRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| @@ -87,14 +87,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
|  | ||||
| 	sizeRatio := 1.0 | ||||
| 	// Size | ||||
| 	if imageRequest.Size == "256x256" { | ||||
| 	if textRequest.Size == "256x256" { | ||||
| 		sizeRatio = 1 | ||||
| 	} else if imageRequest.Size == "512x512" { | ||||
| 	} else if textRequest.Size == "512x512" { | ||||
| 		sizeRatio = 1.125 | ||||
| 	} else if imageRequest.Size == "1024x1024" { | ||||
| 	} else if textRequest.Size == "1024x1024" { | ||||
| 		sizeRatio = 1.25 | ||||
| 	} | ||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||
| 	quota := int(ratio * sizeRatio * 1000) | ||||
|  | ||||
| 	if consumeQuota && userQuota-quota < 0 { | ||||
| 		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) | ||||
| @@ -109,7 +109,8 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
|  | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|   | ||||
| @@ -1,133 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			dataChan <- data | ||||
| 			data = data[6:] | ||||
| 			if !strings.HasPrefix(data, "[DONE]") { | ||||
| 				switch relayMode { | ||||
| 				case RelayModeChatCompletions: | ||||
| 					var streamResponse ChatCompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						return | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Delta.Content | ||||
| 					} | ||||
| 				case RelayModeCompletions: | ||||
| 					var streamResponse CompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						return | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Text | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			if strings.HasPrefix(data, "data: [DONE]") { | ||||
| 				data = data[:12] | ||||
| 			} | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			c.Render(-1, common.CustomEvent{Data: data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse TextResponse | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		err = json.Unmarshal(responseBody, &textResponse) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 		} | ||||
| 		if textResponse.Error.Type != "" { | ||||
| 			return &OpenAIErrorWithStatusCode{ | ||||
| 				OpenAIError: textResponse.Error, | ||||
| 				StatusCode:  resp.StatusCode, | ||||
| 			}, nil | ||||
| 		} | ||||
| 		// Reset response body | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
| 	// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 	// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 	// So the httpClient will be confused by the response. | ||||
| 	// For example, Postman will report error, and we cannot check the response at all. | ||||
| 	for k, v := range resp.Header { | ||||
| 		c.Writer.Header().Set(k, v[0]) | ||||
| 	} | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err := io.Copy(c.Writer, resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &textResponse.Usage | ||||
| } | ||||
| @@ -1,17 +1,10 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| ) | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||
|  | ||||
| type PaLMChatMessage struct { | ||||
| 	Author  string `json:"author"` | ||||
| 	Content string `json:"content"` | ||||
| @@ -22,188 +15,45 @@ type PaLMFilter struct { | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type PaLMPrompt struct { | ||||
| 	Messages []PaLMChatMessage `json:"messages"` | ||||
| } | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body | ||||
| type PaLMChatRequest struct { | ||||
| 	Prompt         PaLMPrompt `json:"prompt"` | ||||
| 	Temperature    float64    `json:"temperature,omitempty"` | ||||
| 	CandidateCount int        `json:"candidateCount,omitempty"` | ||||
| 	TopP           float64    `json:"topP,omitempty"` | ||||
| 	TopK           int        `json:"topK,omitempty"` | ||||
| } | ||||
|  | ||||
| type PaLMError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| 	Status  string `json:"status"` | ||||
| 	Prompt         []Message `json:"prompt"` | ||||
| 	Temperature    float64   `json:"temperature"` | ||||
| 	CandidateCount int       `json:"candidateCount"` | ||||
| 	TopP           float64   `json:"topP"` | ||||
| 	TopK           int       `json:"topK"` | ||||
| } | ||||
|  | ||||
| // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body | ||||
| type PaLMChatResponse struct { | ||||
| 	Candidates []PaLMChatMessage `json:"candidates"` | ||||
| 	Messages   []Message         `json:"messages"` | ||||
| 	Filters    []PaLMFilter      `json:"filters"` | ||||
| 	Error      PaLMError         `json:"error"` | ||||
| 	Candidates []Message    `json:"candidates"` | ||||
| 	Messages   []Message    `json:"messages"` | ||||
| 	Filters    []PaLMFilter `json:"filters"` | ||||
| } | ||||
|  | ||||
| func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { | ||||
| 	palmRequest := PaLMChatRequest{ | ||||
| 		Prompt: PaLMPrompt{ | ||||
| 			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), | ||||
| 		}, | ||||
| 		Temperature:    textRequest.Temperature, | ||||
| 		CandidateCount: textRequest.N, | ||||
| 		TopP:           textRequest.TopP, | ||||
| 		TopK:           textRequest.MaxTokens, | ||||
| 	} | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		palmMessage := PaLMChatMessage{ | ||||
| 			Content: message.Content, | ||||
| 		} | ||||
| func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode { | ||||
| 	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage | ||||
| 	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages)) | ||||
| 	for _, message := range openAIRequest.Messages { | ||||
| 		var author string | ||||
| 		if message.Role == "user" { | ||||
| 			palmMessage.Author = "0" | ||||
| 			author = "0" | ||||
| 		} else { | ||||
| 			palmMessage.Author = "1" | ||||
| 			author = "1" | ||||
| 		} | ||||
| 		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) | ||||
| 		messages = append(messages, PaLMChatMessage{ | ||||
| 			Author:  author, | ||||
| 			Content: message.Content, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &palmRequest | ||||
| } | ||||
|  | ||||
| func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), | ||||
| 	} | ||||
| 	for i, candidate := range response.Candidates { | ||||
| 		choice := OpenAITextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: candidate.Content, | ||||
| 			}, | ||||
| 			FinishReason: "stop", | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	if len(palmResponse.Candidates) > 0 { | ||||
| 		choice.Delta.Content = palmResponse.Candidates[0].Content | ||||
| 	} | ||||
| 	choice.FinishReason = "stop" | ||||
| 	var response ChatCompletionsStreamResponse | ||||
| 	response.Object = "chat.completion.chunk" | ||||
| 	response.Model = "palm2" | ||||
| 	response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||
| 	responseText := "" | ||||
| 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | ||||
| 	createdTime := common.GetTimestamp() | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error reading stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			common.SysError("error closing stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		var palmResponse PaLMChatResponse | ||||
| 		err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) | ||||
| 		fullTextResponse.Id = responseId | ||||
| 		fullTextResponse.Created = createdTime | ||||
| 		if len(palmResponse.Candidates) > 0 { | ||||
| 			responseText = palmResponse.Candidates[0].Content | ||||
| 		} | ||||
| 		jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 		if err != nil { | ||||
| 			common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 			stopChan <- true | ||||
| 			return | ||||
| 		} | ||||
| 		dataChan <- string(jsonResponse) | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + data}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||
| 	} | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var palmResponse PaLMChatResponse | ||||
| 	err = json.Unmarshal(responseBody, &palmResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: palmResponse.Error.Message, | ||||
| 				Type:    palmResponse.Error.Status, | ||||
| 				Param:   "", | ||||
| 				Code:    palmResponse.Error.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responsePaLM2OpenAI(&palmResponse) | ||||
| 	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) | ||||
| 	usage := Usage{ | ||||
| 		PromptTokens:     promptTokens, | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TotalTokens:      promptTokens + completionTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| 	request := PaLMChatRequest{ | ||||
| 		Prompt:         nil, | ||||
| 		Temperature:    openAIRequest.Temperature, | ||||
| 		CandidateCount: openAIRequest.N, | ||||
| 		TopP:           openAIRequest.TopP, | ||||
| 		TopK:           openAIRequest.MaxTokens, | ||||
| 	} | ||||
| 	// TODO: forward request to PaLM & convert response | ||||
| 	fmt.Print(request) | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,35 +1,19 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	APITypeOpenAI = iota | ||||
| 	APITypeClaude | ||||
| 	APITypePaLM | ||||
| 	APITypeBaidu | ||||
| 	APITypeZhipu | ||||
| 	APITypeAli | ||||
| 	APITypeXunfei | ||||
| ) | ||||
|  | ||||
| var httpClient *http.Client | ||||
|  | ||||
| func init() { | ||||
| 	httpClient = &http.Client{} | ||||
| } | ||||
|  | ||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| @@ -46,9 +30,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	if relayMode == RelayModeModerations && textRequest.Model == "" { | ||||
| 		textRequest.Model = "text-moderation-latest" | ||||
| 	} | ||||
| 	if relayMode == RelayModeEmbeddings && textRequest.Model == "" { | ||||
| 		textRequest.Model = c.Param("model") | ||||
| 	} | ||||
| 	// request validation | ||||
| 	if textRequest.Model == "" { | ||||
| 		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) | ||||
| @@ -75,7 +56,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" && modelMapping != "{}" { | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| @@ -86,83 +67,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			isModelMapped = true | ||||
| 		} | ||||
| 	} | ||||
| 	apiType := APITypeOpenAI | ||||
| 	switch channelType { | ||||
| 	case common.ChannelTypeAnthropic: | ||||
| 		apiType = APITypeClaude | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		apiType = APITypeBaidu | ||||
| 	case common.ChannelTypePaLM: | ||||
| 		apiType = APITypePaLM | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		apiType = APITypeZhipu | ||||
| 	case common.ChannelTypeAli: | ||||
| 		apiType = APITypeAli | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		apiType = APITypeXunfei | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if channelType == common.ChannelTypeAzure { | ||||
| 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 			query := c.Request.URL.Query() | ||||
| 			apiVersion := query.Get("api-version") | ||||
| 			if apiVersion == "" { | ||||
| 				apiVersion = c.GetString("api_version") | ||||
| 			} | ||||
| 			requestURL := strings.Split(requestURL, "?")[0] | ||||
| 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||
| 			baseURL = c.GetString("base_url") | ||||
| 			task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 			model_ := textRequest.Model | ||||
| 			model_ = strings.Replace(model_, ".", "", -1) | ||||
| 			// https://github.com/songquanpeng/one-api/issues/67 | ||||
| 			model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 			model_ = strings.TrimSuffix(model_, "-0613") | ||||
| 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||
| 	if channelType == common.ChannelTypeAzure { | ||||
| 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 		query := c.Request.URL.Query() | ||||
| 		apiVersion := query.Get("api-version") | ||||
| 		if apiVersion == "" { | ||||
| 			apiVersion = c.GetString("api_version") | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		fullRequestURL = "https://api.anthropic.com/v1/complete" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) | ||||
| 		} | ||||
| 	case APITypeBaidu: | ||||
| 		switch textRequest.Model { | ||||
| 		case "ERNIE-Bot": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" | ||||
| 		case "ERNIE-Bot-turbo": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days | ||||
| 	case APITypePaLM: | ||||
| 		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" | ||||
| 		if baseURL != "" { | ||||
| 			fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		fullRequestURL += "?key=" + apiKey | ||||
| 	case APITypeZhipu: | ||||
| 		method := "invoke" | ||||
| 		if textRequest.Stream { | ||||
| 			method = "sse-invoke" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||
| 	case APITypeAli: | ||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||
| 		requestURL := strings.Split(requestURL, "?")[0] | ||||
| 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 		task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 		model_ := textRequest.Model | ||||
| 		model_ = strings.Replace(model_, ".", "", -1) | ||||
| 		// https://github.com/songquanpeng/one-api/issues/67 | ||||
| 		model_ = strings.TrimSuffix(model_, "-0301") | ||||
| 		model_ = strings.TrimSuffix(model_, "-0314") | ||||
| 		model_ = strings.TrimSuffix(model_, "-0613") | ||||
| 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | ||||
| 	} else if channelType == common.ChannelTypePaLM { | ||||
| 		err := relayPaLM(textRequest, c) | ||||
| 		return err | ||||
| 	} | ||||
| 	var promptTokens int | ||||
| 	var completionTokens int | ||||
| @@ -207,105 +138,35 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	} else { | ||||
| 		requestBody = c.Request.Body | ||||
| 	} | ||||
| 	switch apiType { | ||||
| 	case APITypeClaude: | ||||
| 		claudeRequest := requestOpenAI2Claude(textRequest) | ||||
| 		jsonStr, err := json.Marshal(claudeRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeBaidu: | ||||
| 		var jsonData []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case RelayModeEmbeddings: | ||||
| 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||
| 		default: | ||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonData) | ||||
| 	case APITypePaLM: | ||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) | ||||
| 		jsonStr, err := json.Marshal(palmRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeZhipu: | ||||
| 		zhipuRequest := requestOpenAI2Zhipu(textRequest) | ||||
| 		jsonStr, err := json.Marshal(zhipuRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeAli: | ||||
| 		aliRequest := requestOpenAI2Ali(textRequest) | ||||
| 		jsonStr, err := json.Marshal(aliRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	var req *http.Request | ||||
| 	var resp *http.Response | ||||
| 	isStream := textRequest.Stream | ||||
|  | ||||
| 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		switch apiType { | ||||
| 		case APITypeOpenAI: | ||||
| 			if channelType == common.ChannelTypeAzure { | ||||
| 				req.Header.Set("api-key", apiKey) | ||||
| 			} else { | ||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 			} | ||||
| 		case APITypeClaude: | ||||
| 			req.Header.Set("x-api-key", apiKey) | ||||
| 			anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 			if anthropicVersion == "" { | ||||
| 				anthropicVersion = "2023-06-01" | ||||
| 			} | ||||
| 			req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 		case APITypeZhipu: | ||||
| 			token := getZhipuToken(apiKey) | ||||
| 			req.Header.Set("Authorization", token) | ||||
| 		case APITypeAli: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 			if textRequest.Stream { | ||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 			} | ||||
| 		} | ||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = req.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = c.Request.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		isStream = strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	if channelType == common.ChannelTypeAzure { | ||||
| 		key := c.Request.Header.Get("Authorization") | ||||
| 		key = strings.TrimPrefix(key, "Bearer ") | ||||
| 		req.Header.Set("api-key", key) | ||||
| 	} else { | ||||
| 		req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 	//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 	client := &http.Client{} | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
|  | ||||
| 	var textResponse TextResponse | ||||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	var streamResponseText string | ||||
|  | ||||
| 	defer func() { | ||||
| @@ -318,15 +179,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if strings.HasPrefix(textRequest.Model, "gpt-4") { | ||||
| 				completionRatio = 2 | ||||
| 			} | ||||
| 			if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu && apiType != APITypeAli { | ||||
| 			if isStream { | ||||
| 				completionTokens = countTokenText(streamResponseText, textRequest.Model) | ||||
| 			} else { | ||||
| 				promptTokens = textResponse.Usage.PromptTokens | ||||
| 				completionTokens = textResponse.Usage.CompletionTokens | ||||
| 				if apiType == APITypeZhipu { | ||||
| 					// zhipu's API does not return prompt tokens & completion tokens | ||||
| 					promptTokens = textResponse.Usage.TotalTokens | ||||
| 				} | ||||
| 			} | ||||
| 			quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||
| 			quota = int(float64(quota) * ratio) | ||||
| @@ -358,148 +215,123 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if isStream { | ||||
| 			err, responseText := openaiStreamHandler(c, resp, relayMode) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
|  | ||||
| 	if isStream { | ||||
| 		scanner := bufio.NewScanner(resp.Body) | ||||
| 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 			if atEOF && len(data) == 0 { | ||||
| 				return 0, nil, nil | ||||
| 			} | ||||
| 			streamResponseText = responseText | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := openaiHandler(c, resp, consumeQuota) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
|  | ||||
| 			if i := strings.Index(string(data), "\n\n"); i >= 0 { | ||||
| 				return i + 2, data[0:i], nil | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
|  | ||||
| 			if atEOF { | ||||
| 				return len(data), data, nil | ||||
| 			} | ||||
| 			return nil | ||||
|  | ||||
| 			return 0, nil, nil | ||||
| 		}) | ||||
| 		dataChan := make(chan string) | ||||
| 		stopChan := make(chan bool) | ||||
| 		go func() { | ||||
| 			for scanner.Scan() { | ||||
| 				data := scanner.Text() | ||||
| 				if len(data) < 6 { // must be something wrong! | ||||
| 					common.SysError("invalid stream response: " + data) | ||||
| 					continue | ||||
| 				} | ||||
| 				dataChan <- data | ||||
| 				data = data[6:] | ||||
| 				if !strings.HasPrefix(data, "[DONE]") { | ||||
| 					switch relayMode { | ||||
| 					case RelayModeChatCompletions: | ||||
| 						var streamResponse ChatCompletionsStreamResponse | ||||
| 						err = json.Unmarshal([]byte(data), &streamResponse) | ||||
| 						if err != nil { | ||||
| 							common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 							return | ||||
| 						} | ||||
| 						for _, choice := range streamResponse.Choices { | ||||
| 							streamResponseText += choice.Delta.Content | ||||
| 						} | ||||
| 					case RelayModeCompletions: | ||||
| 						var streamResponse CompletionsStreamResponse | ||||
| 						err = json.Unmarshal([]byte(data), &streamResponse) | ||||
| 						if err != nil { | ||||
| 							common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 							return | ||||
| 						} | ||||
| 						for _, choice := range streamResponse.Choices { | ||||
| 							streamResponseText += choice.Text | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			stopChan <- true | ||||
| 		}() | ||||
| 		c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 		c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 		c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 		c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 		c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 		c.Stream(func(w io.Writer) bool { | ||||
| 			select { | ||||
| 			case data := <-dataChan: | ||||
| 				if strings.HasPrefix(data, "data: [DONE]") { | ||||
| 					data = data[:12] | ||||
| 				} | ||||
| 				c.Render(-1, common.CustomEvent{Data: data}) | ||||
| 				return true | ||||
| 			case <-stopChan: | ||||
| 				return false | ||||
| 			} | ||||
| 		}) | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		if isStream { | ||||
| 			err, responseText := claudeStreamHandler(c, resp) | ||||
| 		return nil | ||||
| 	} else { | ||||
| 		if consumeQuota { | ||||
| 			responseBody, err := io.ReadAll(resp.Body) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 				return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||||
| 			} | ||||
| 			streamResponseText = responseText | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			err = resp.Body.Close() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			err = json.Unmarshal(responseBody, &textResponse) | ||||
| 			if err != nil { | ||||
| 				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||||
| 			} | ||||
| 			return nil | ||||
| 			if textResponse.Error.Type != "" { | ||||
| 				return &OpenAIErrorWithStatusCode{ | ||||
| 					OpenAIError: textResponse.Error, | ||||
| 					StatusCode:  resp.StatusCode, | ||||
| 				} | ||||
| 			} | ||||
| 			// Reset response body | ||||
| 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 		} | ||||
| 	case APITypeBaidu: | ||||
| 		if isStream { | ||||
| 			err, usage := baiduStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			var err *OpenAIErrorWithStatusCode | ||||
| 			var usage *Usage | ||||
| 			switch relayMode { | ||||
| 			case RelayModeEmbeddings: | ||||
| 				err, usage = baiduEmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = baiduHandler(c, resp) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		// We shouldn't set the header before we parse the response body, because the parse part may fail. | ||||
| 		// And then we will have to send an error response, but in this case, the header has already been set. | ||||
| 		// So the client will be confused by the response. | ||||
| 		// For example, Postman will report error, and we cannot check the response at all. | ||||
| 		for k, v := range resp.Header { | ||||
| 			c.Writer.Header().Set(k, v[0]) | ||||
| 		} | ||||
| 	case APITypePaLM: | ||||
| 		if textRequest.Stream { // PaLM2 API does not support stream | ||||
| 			err, responseText := palmStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			streamResponseText = responseText | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		c.Writer.WriteHeader(resp.StatusCode) | ||||
| 		_, err = io.Copy(c.Writer, resp.Body) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 	case APITypeZhipu: | ||||
| 		if isStream { | ||||
| 			err, usage := zhipuStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := zhipuHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		err = resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 	case APITypeAli: | ||||
| 		if isStream { | ||||
| 			err, usage := aliStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := aliHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeXunfei: | ||||
| 		if isStream { | ||||
| 			auth := c.Request.Header.Get("Authorization") | ||||
| 			auth = strings.TrimPrefix(auth, "Bearer ") | ||||
| 			splits := strings.Split(auth, "|") | ||||
| 			if len(splits) != 3 { | ||||
| 				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||
| 			} | ||||
| 			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) | ||||
| 		} | ||||
| 	default: | ||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -91,16 +91,3 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus | ||||
| 		StatusCode:  statusCode, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func shouldDisableChannel(err *OpenAIError) bool { | ||||
| 	if !common.AutomaticDisableChannelEnabled { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
| @@ -1,274 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://console.xfyun.cn/services/cbm | ||||
| // https://www.xfyun.cn/doc/spark/Web.html | ||||
|  | ||||
| type XunfeiMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatRequest struct { | ||||
| 	Header struct { | ||||
| 		AppId string `json:"app_id"` | ||||
| 	} `json:"header"` | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string  `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int     `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool    `json:"auditing,omitempty"` | ||||
| 		} `json:"chat"` | ||||
| 	} `json:"parameter"` | ||||
| 	Payload struct { | ||||
| 		Message struct { | ||||
| 			Text []XunfeiMessage `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponse struct { | ||||
| 	Header struct { | ||||
| 		Code    int    `json:"code"` | ||||
| 		Message string `json:"message"` | ||||
| 		Sid     string `json:"sid"` | ||||
| 		Status  int    `json:"status"` | ||||
| 	} `json:"header"` | ||||
| 	Payload struct { | ||||
| 		Choices struct { | ||||
| 			Status int                          `json:"status"` | ||||
| 			Seq    int                          `json:"seq"` | ||||
| 			Text   []XunfeiChatResponseTextItem `json:"text"` | ||||
| 		} `json:"choices"` | ||||
| 	} `json:"payload"` | ||||
| 	Usage struct { | ||||
| 		//Text struct { | ||||
| 		//	QuestionTokens   string `json:"question_tokens"` | ||||
| 		//	PromptTokens     string `json:"prompt_tokens"` | ||||
| 		//	CompletionTokens string `json:"completion_tokens"` | ||||
| 		//	TotalTokens      string `json:"total_tokens"` | ||||
| 		//} `json:"text"` | ||||
| 		Text Usage `json:"text"` | ||||
| 	} `json:"usage"` | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest { | ||||
| 	messages := make([]XunfeiMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	xunfeiRequest := XunfeiChatRequest{} | ||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||
| 	xunfeiRequest.Parameter.Chat.Domain = "general" | ||||
| 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||
| 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
| 	return &xunfeiRequest | ||||
| } | ||||
|  | ||||
| func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Payload.Choices.Text[0].Content, | ||||
| 		}, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage:   response.Usage.Text, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "SparkDesk", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||
| 	HmacWithShaToBase64 := func(algorithm, data, key string) string { | ||||
| 		mac := hmac.New(sha256.New, []byte(key)) | ||||
| 		mac.Write([]byte(data)) | ||||
| 		encodeData := mac.Sum(nil) | ||||
| 		return base64.StdEncoding.EncodeToString(encodeData) | ||||
| 	} | ||||
| 	ul, err := url.Parse(hostUrl) | ||||
| 	if err != nil { | ||||
| 		fmt.Println(err) | ||||
| 	} | ||||
| 	date := time.Now().UTC().Format(time.RFC1123) | ||||
| 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||
| 	sign := strings.Join(signString, "\n") | ||||
| 	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) | ||||
| 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||
| 		"hmac-sha256", "host date request-line", sha) | ||||
| 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||
| 	v := url.Values{} | ||||
| 	v.Add("host", ul.Host) | ||||
| 	v.Add("date", date) | ||||
| 	v.Add("authorization", authorization) | ||||
| 	callUrl := hostUrl + "?" + v.Encode() | ||||
| 	return callUrl | ||||
| } | ||||
|  | ||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	d := websocket.Dialer{ | ||||
| 		HandshakeTimeout: 5 * time.Second, | ||||
| 	} | ||||
| 	hostUrl := "wss://aichat.xf-yun.com/v1/chat" | ||||
| 	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) | ||||
| 	if err != nil || resp.StatusCode != 101 { | ||||
| 		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	data := requestOpenAI2Xunfei(textRequest, appId) | ||||
| 	err = conn.WriteJSON(data) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	dataChan := make(chan XunfeiChatResponse) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := conn.ReadMessage() | ||||
| 			if err != nil { | ||||
| 				common.SysError("error reading stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			var response XunfeiChatResponse | ||||
| 			err = json.Unmarshal(msg, &response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			dataChan <- response | ||||
| 			if response.Payload.Choices.Status == 2 { | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case xunfeiResponse := <-dataChan: | ||||
| 			usage.PromptTokens += xunfeiResponse.Usage.Text.PromptTokens | ||||
| 			usage.CompletionTokens += xunfeiResponse.Usage.Text.CompletionTokens | ||||
| 			usage.TotalTokens += xunfeiResponse.Usage.Text.TotalTokens | ||||
| 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var xunfeiResponse XunfeiChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &xunfeiResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if xunfeiResponse.Header.Code != 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: xunfeiResponse.Header.Message, | ||||
| 				Type:    "xunfei_error", | ||||
| 				Param:   "", | ||||
| 				Code:    xunfeiResponse.Header.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -1,301 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/golang-jwt/jwt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://open.bigmodel.cn/doc/api#chatglm_std | ||||
| // chatglm_std, chatglm_lite | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke | ||||
| // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke | ||||
|  | ||||
| type ZhipuMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type ZhipuRequest struct { | ||||
| 	Prompt      []ZhipuMessage `json:"prompt"` | ||||
| 	Temperature float64        `json:"temperature,omitempty"` | ||||
| 	TopP        float64        `json:"top_p,omitempty"` | ||||
| 	RequestId   string         `json:"request_id,omitempty"` | ||||
| 	Incremental bool           `json:"incremental,omitempty"` | ||||
| } | ||||
|  | ||||
| type ZhipuResponseData struct { | ||||
| 	TaskId     string         `json:"task_id"` | ||||
| 	RequestId  string         `json:"request_id"` | ||||
| 	TaskStatus string         `json:"task_status"` | ||||
| 	Choices    []ZhipuMessage `json:"choices"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ZhipuResponse struct { | ||||
| 	Code    int               `json:"code"` | ||||
| 	Msg     string            `json:"msg"` | ||||
| 	Success bool              `json:"success"` | ||||
| 	Data    ZhipuResponseData `json:"data"` | ||||
| } | ||||
|  | ||||
| type ZhipuStreamMetaResponse struct { | ||||
| 	RequestId  string `json:"request_id"` | ||||
| 	TaskId     string `json:"task_id"` | ||||
| 	TaskStatus string `json:"task_status"` | ||||
| 	Usage      `json:"usage"` | ||||
| } | ||||
|  | ||||
| type zhipuTokenData struct { | ||||
| 	Token      string | ||||
| 	ExpiryTime time.Time | ||||
| } | ||||
|  | ||||
| var zhipuTokens sync.Map | ||||
| var expSeconds int64 = 24 * 3600 | ||||
|  | ||||
| func getZhipuToken(apikey string) string { | ||||
| 	data, ok := zhipuTokens.Load(apikey) | ||||
| 	if ok { | ||||
| 		tokenData := data.(zhipuTokenData) | ||||
| 		if time.Now().Before(tokenData.ExpiryTime) { | ||||
| 			return tokenData.Token | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	split := strings.Split(apikey, ".") | ||||
| 	if len(split) != 2 { | ||||
| 		common.SysError("invalid zhipu key: " + apikey) | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	id := split[0] | ||||
| 	secret := split[1] | ||||
|  | ||||
| 	expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 | ||||
| 	expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) | ||||
|  | ||||
| 	timestamp := time.Now().UnixNano() / 1e6 | ||||
|  | ||||
| 	payload := jwt.MapClaims{ | ||||
| 		"api_key":   id, | ||||
| 		"exp":       expMillis, | ||||
| 		"timestamp": timestamp, | ||||
| 	} | ||||
|  | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) | ||||
|  | ||||
| 	token.Header["alg"] = "HS256" | ||||
| 	token.Header["sign_type"] = "SIGN" | ||||
|  | ||||
| 	tokenString, err := token.SignedString([]byte(secret)) | ||||
| 	if err != nil { | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	zhipuTokens.Store(apikey, zhipuTokenData{ | ||||
| 		Token:      tokenString, | ||||
| 		ExpiryTime: expiryTime, | ||||
| 	}) | ||||
|  | ||||
| 	return tokenString | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &ZhipuRequest{ | ||||
| 		Prompt:      messages, | ||||
| 		Temperature: request.Temperature, | ||||
| 		TopP:        request.TopP, | ||||
| 		Incremental: false, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.Data.TaskId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), | ||||
| 		Usage:   response.Data.Usage, | ||||
| 	} | ||||
| 	for i, choice := range response.Data.Choices { | ||||
| 		openaiChoice := OpenAITextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: Message{ | ||||
| 				Role:    choice.Role, | ||||
| 				Content: strings.Trim(choice.Content, "\""), | ||||
| 			}, | ||||
| 			FinishReason: "", | ||||
| 		} | ||||
| 		if i == len(response.Data.Choices)-1 { | ||||
| 			openaiChoice.FinishReason = "stop" | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = zhipuResponse | ||||
| 	choice.FinishReason = "" | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = "" | ||||
| 	choice.FinishReason = "stop" | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      zhipuResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "chatglm", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response, &zhipuResponse.Usage | ||||
| } | ||||
|  | ||||
| func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage *Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	metaChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			data = strings.Trim(data, "\"") | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] == "data:" { | ||||
| 				dataChan <- data[5:] | ||||
| 			} else if data[:5] == "meta:" { | ||||
| 				metaChan <- data[5:] | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			response := streamResponseZhipu2OpenAI(data) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case data := <-metaChan: | ||||
| 			var zhipuResponse ZhipuStreamMetaResponse | ||||
| 			err := json.Unmarshal([]byte(data), &zhipuResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage = zhipuUsage | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
| func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var zhipuResponse ZhipuResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if !zhipuResponse.Success { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: zhipuResponse.Msg, | ||||
| 				Type:    "zhipu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    zhipuResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -4,7 +4,6 @@ import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -56,12 +55,6 @@ type TextRequest struct { | ||||
| 	//Stream   bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| type ImageRequest struct { | ||||
| 	Prompt string `json:"prompt"` | ||||
| 	N      int    `json:"n"` | ||||
| 	Size   string `json:"size"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	PromptTokens     int `json:"prompt_tokens"` | ||||
| 	CompletionTokens int `json:"completion_tokens"` | ||||
| @@ -85,33 +78,6 @@ type TextResponse struct { | ||||
| 	Error OpenAIError `json:"error"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponseChoice struct { | ||||
| 	Index        int `json:"index"` | ||||
| 	Message      `json:"message"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponse struct { | ||||
| 	Id      string                     `json:"id"` | ||||
| 	Object  string                     `json:"object"` | ||||
| 	Created int64                      `json:"created"` | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponseItem struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Index     int       `json:"index"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponse struct { | ||||
| 	Object string                        `json:"object"` | ||||
| 	Data   []OpenAIEmbeddingResponseItem `json:"data"` | ||||
| 	Model  string                        `json:"model"` | ||||
| 	Usage  `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ImageResponse struct { | ||||
| 	Created int `json:"created"` | ||||
| 	Data    []struct { | ||||
| @@ -119,19 +85,13 @@ type ImageResponse struct { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponseChoice struct { | ||||
| 	Delta struct { | ||||
| 		Content string `json:"content"` | ||||
| 	} `json:"delta"` | ||||
| 	FinishReason string `json:"finish_reason,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponse struct { | ||||
| 	Id      string                                `json:"id"` | ||||
| 	Object  string                                `json:"object"` | ||||
| 	Created int64                                 `json:"created"` | ||||
| 	Model   string                                `json:"model"` | ||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||
| 	Choices []struct { | ||||
| 		Delta struct { | ||||
| 			Content string `json:"content"` | ||||
| 		} `json:"delta"` | ||||
| 		FinishReason string `json:"finish_reason"` | ||||
| 	} `json:"choices"` | ||||
| } | ||||
|  | ||||
| type CompletionsStreamResponse struct { | ||||
| @@ -149,8 +109,6 @@ func Relay(c *gin.Context) { | ||||
| 		relayMode = RelayModeCompletions | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		relayMode = RelayModeEmbeddings | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 		relayMode = RelayModeModerations | ||||
| 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| @@ -166,25 +124,16 @@ func Relay(c *gin.Context) { | ||||
| 		err = relayTextHelper(c, relayMode) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		retryTimesStr := c.Query("retry") | ||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||
| 		if retryTimesStr == "" { | ||||
| 			retryTimes = common.RetryTimes | ||||
| 		} | ||||
| 		if retryTimes > 0 { | ||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) | ||||
| 		} else { | ||||
| 			if err.StatusCode == http.StatusTooManyRequests { | ||||
| 				err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" | ||||
| 			} | ||||
| 			c.JSON(err.StatusCode, gin.H{ | ||||
| 				"error": err.OpenAIError, | ||||
| 			}) | ||||
| 		if err.StatusCode == http.StatusTooManyRequests { | ||||
| 			err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" | ||||
| 		} | ||||
| 		c.JSON(err.StatusCode, gin.H{ | ||||
| 			"error": err.OpenAIError, | ||||
| 		}) | ||||
| 		channelId := c.GetInt("channel_id") | ||||
| 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 		if shouldDisableChannel(&err.OpenAIError) { | ||||
| 		if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") { | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			channelName := c.GetString("channel_name") | ||||
| 			disableChannel(channelId, channelName, err.Message) | ||||
|   | ||||
| @@ -3,13 +3,12 @@ package controller | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type LoginRequest struct { | ||||
| @@ -478,16 +477,6 @@ func DeleteUser(c *gin.Context) { | ||||
|  | ||||
| func DeleteSelf(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	user, _ := model.GetUserById(id, false) | ||||
|  | ||||
| 	if user.Role == common.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "不能删除超级管理员账户", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := model.DeleteUserById(id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
							
								
								
									
										36
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								go.mod
									
									
									
									
									
								
							| @@ -9,37 +9,37 @@ require ( | ||||
| 	github.com/gin-contrib/sessions v0.0.5 | ||||
| 	github.com/gin-contrib/static v0.0.1 | ||||
| 	github.com/gin-gonic/gin v1.9.1 | ||||
| 	github.com/go-playground/validator/v10 v10.14.0 | ||||
| 	github.com/go-playground/validator/v10 v10.14.1 | ||||
| 	github.com/go-redis/redis/v8 v8.11.5 | ||||
| 	github.com/golang-jwt/jwt v3.2.2+incompatible | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.1 | ||||
| 	golang.org/x/crypto v0.9.0 | ||||
| 	gorm.io/driver/mysql v1.4.3 | ||||
| 	gorm.io/driver/sqlite v1.4.3 | ||||
| 	gorm.io/gorm v1.24.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.4 | ||||
| 	golang.org/x/crypto v0.11.0 | ||||
| 	gorm.io/driver/mysql v1.5.1 | ||||
| 	gorm.io/driver/sqlite v1.5.2 | ||||
| 	gorm.io/gorm v1.25.2 | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | ||||
| 	github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect | ||||
| 	github.com/bytedance/sonic v1.9.2 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.2.0 // indirect | ||||
| 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | ||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||
| 	github.com/dlclark/regexp2 v1.8.1 // indirect | ||||
| 	github.com/dlclark/regexp2 v1.10.0 // indirect | ||||
| 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | ||||
| 	github.com/gin-contrib/sse v0.1.0 // indirect | ||||
| 	github.com/go-playground/locales v0.14.1 // indirect | ||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.6.0 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.7.1 // indirect | ||||
| 	github.com/goccy/go-json v0.10.2 // indirect | ||||
| 	github.com/gomodule/redigo v2.0.0+incompatible // indirect | ||||
| 	github.com/gorilla/context v1.1.1 // indirect | ||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | ||||
| 	github.com/gorilla/sessions v1.2.1 // indirect | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/jinzhu/now v1.1.5 // indirect | ||||
| 	github.com/json-iterator/go v1.1.12 // indirect | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.5 // indirect | ||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.19 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | ||||
| @@ -48,10 +48,10 @@ require ( | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| 	golang.org/x/net v0.10.0 // indirect | ||||
| 	golang.org/x/sys v0.8.0 // indirect | ||||
| 	golang.org/x/text v0.9.0 // indirect | ||||
| 	google.golang.org/protobuf v1.30.0 // indirect | ||||
| 	golang.org/x/arch v0.4.0 // indirect | ||||
| 	golang.org/x/net v0.12.0 // indirect | ||||
| 	golang.org/x/sys v0.10.0 // indirect | ||||
| 	golang.org/x/text v0.11.0 // indirect | ||||
| 	google.golang.org/protobuf v1.31.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										44
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										44
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,8 +1,14 @@ | ||||
| github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04= | ||||
| github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw= | ||||
| github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | ||||
| github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | ||||
| github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||
| github.com/bytedance/sonic v1.9.2 h1:GDaNjuWSGu09guE9Oql0MSTNhNCLlWwO8y/xM5BzcbM= | ||||
| github.com/bytedance/sonic v1.9.2/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||
| github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= | ||||
| github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= | ||||
| github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= | ||||
| @@ -14,6 +20,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r | ||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | ||||
| github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= | ||||
| github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | ||||
| github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= | ||||
| github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | ||||
| github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= | ||||
| @@ -45,17 +53,22 @@ github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GO | ||||
| github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= | ||||
| github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= | ||||
| github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= | ||||
| github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+jU0zvx4AqHGnv4k= | ||||
| github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= | ||||
| github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= | ||||
| github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= | ||||
| github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= | ||||
| github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= | ||||
| github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= | ||||
| github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= | ||||
| github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= | ||||
| github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||
| github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= | ||||
| github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||
| github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= | ||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||
| github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= | ||||
| github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= | ||||
| github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= | ||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||
| @@ -65,10 +78,9 @@ github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8 | ||||
| github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | ||||
| github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | ||||
| github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | ||||
| github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= | ||||
| github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | ||||
| github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | ||||
| github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||
| github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| @@ -80,6 +92,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm | ||||
| github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= | ||||
| github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= | ||||
| github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= | ||||
| github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= | ||||
| github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= | ||||
| github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= | ||||
| @@ -114,6 +128,8 @@ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNc | ||||
| github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= | ||||
| github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo= | ||||
| github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= | ||||
| github.com/pkoukk/tiktoken-go v0.1.4 h1:bniMzWdUvNO6YkRbASo2x5qJf2LAG/TIJojqz+Igm8E= | ||||
| github.com/pkoukk/tiktoken-go v0.1.4/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | ||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||
| github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= | ||||
| @@ -143,27 +159,38 @@ github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ | ||||
| golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc= | ||||
| golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= | ||||
| golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= | ||||
| golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= | ||||
| golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= | ||||
| golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= | ||||
| golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= | ||||
| golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= | ||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= | ||||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= | ||||
| golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= | ||||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | ||||
| golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= | ||||
| golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| @@ -171,6 +198,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 | ||||
| google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= | ||||
| google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= | ||||
| google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||
| @@ -187,9 +216,16 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||
| gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= | ||||
| gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= | ||||
| gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw= | ||||
| gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o= | ||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | ||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | ||||
| gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc= | ||||
| gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= | ||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | ||||
| gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= | ||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | ||||
| gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||
| gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho= | ||||
| gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||
| rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= | ||||
|   | ||||
| @@ -503,12 +503,5 @@ | ||||
|   "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", | ||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||
|   "Homepage URL 填": "Fill in the Homepage URL", | ||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||
|   "请为通道命名": "Please name the channel", | ||||
|   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||
|   "模型重定向": "Model redirection", | ||||
|   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||
|   "注意,": "Note that, ", | ||||
|   ",图片演示。": "related image demo.", | ||||
|   "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!" | ||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL" | ||||
| } | ||||
|   | ||||
							
								
								
									
										1
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								main.go
									
									
									
									
									
								
							| @@ -54,7 +54,6 @@ func main() { | ||||
| 		if err != nil { | ||||
| 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) | ||||
| 		} | ||||
| 		common.SyncFrequency = frequency | ||||
| 		go model.SyncOptions(frequency) | ||||
| 		if common.RedisEnabled { | ||||
| 			go model.SyncChannelCache(frequency) | ||||
|   | ||||
| @@ -74,11 +74,6 @@ func Distribute() func(c *gin.Context) { | ||||
| 					modelRequest.Model = "text-moderation-stable" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = c.Param("model") | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e" | ||||
| @@ -86,7 +81,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 			} | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				message := "无可用渠道" | ||||
| 				if channel != nil { | ||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					message = "数据库一致性已被破坏,请联系管理员" | ||||
|   | ||||
| @@ -12,11 +12,11 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	TokenCacheSeconds         = common.SyncFrequency | ||||
| 	UserId2GroupCacheSeconds  = common.SyncFrequency | ||||
| 	UserId2QuotaCacheSeconds  = common.SyncFrequency | ||||
| 	UserId2StatusCacheSeconds = common.SyncFrequency | ||||
| const ( | ||||
| 	TokenCacheSeconds         = 60 * 60 | ||||
| 	UserId2GroupCacheSeconds  = 60 * 60 | ||||
| 	UserId2QuotaCacheSeconds  = 10 * 60 | ||||
| 	UserId2StatusCacheSeconds = 60 * 60 | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| @@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set token error: " + err.Error()) | ||||
| 		} | ||||
| @@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) { | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user group error: " + err.Error()) | ||||
| 		} | ||||
| @@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user quota error: " + err.Error()) | ||||
| 		} | ||||
| @@ -91,7 +91,7 @@ func CacheUpdateUserQuota(id int) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -106,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool { | ||||
| 			status = common.UserStatusEnabled | ||||
| 		} | ||||
| 		enabled = fmt.Sprintf("%d", status) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second) | ||||
| 		if err != nil { | ||||
| 			common.SysError("Redis set user enabled error: " + err.Error()) | ||||
| 		} | ||||
|   | ||||
| @@ -68,7 +68,6 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["TopUpLink"] = common.TopUpLink | ||||
| 	common.OptionMap["ChatLink"] = common.ChatLink | ||||
| 	common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) | ||||
| 	common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) | ||||
| 	common.OptionMapRWMutex.Unlock() | ||||
| 	loadOptionsFromDatabase() | ||||
| } | ||||
| @@ -197,8 +196,6 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		common.QuotaRemindThreshold, _ = strconv.Atoi(value) | ||||
| 	case "PreConsumedQuota": | ||||
| 		common.PreConsumedQuota, _ = strconv.Atoi(value) | ||||
| 	case "RetryTimes": | ||||
| 		common.RetryTimes, _ = strconv.Atoi(value) | ||||
| 	case "ModelRatio": | ||||
| 		err = common.UpdateModelRatioByJSONString(value) | ||||
| 	case "GroupRatio": | ||||
|   | ||||
| @@ -51,21 +51,20 @@ func Redeem(key string, userId int) (quota int, err error) { | ||||
| 	redemption := &Redemption{} | ||||
|  | ||||
| 	err = DB.Transaction(func(tx *gorm.DB) error { | ||||
| 		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error | ||||
| 		err := DB.Where("`key` = ?", key).First(redemption).Error | ||||
| 		if err != nil { | ||||
| 			return errors.New("无效的兑换码") | ||||
| 		} | ||||
| 		if redemption.Status != common.RedemptionCodeStatusEnabled { | ||||
| 			return errors.New("该兑换码已被使用") | ||||
| 		} | ||||
| 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | ||||
| 		err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		redemption.RedeemedTime = common.GetTimestamp() | ||||
| 		redemption.Status = common.RedemptionCodeStatusUsed | ||||
| 		err = tx.Save(redemption).Error | ||||
| 		return err | ||||
| 		return redemption.SelectUpdate() | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return 0, errors.New("兑换失败," + err.Error()) | ||||
|   | ||||
| @@ -28,7 +28,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 		userRoute := apiRouter.Group("/user") | ||||
| 		{ | ||||
| 			userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) | ||||
| 			userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login) | ||||
| 			userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login) | ||||
| 			userRoute.GET("/logout", controller.Logout) | ||||
|  | ||||
| 			selfRoute := userRoute.Group("/") | ||||
| @@ -36,7 +36,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 			{ | ||||
| 				selfRoute.GET("/self", controller.GetSelf) | ||||
| 				selfRoute.PUT("/self", controller.UpdateSelf) | ||||
| 				selfRoute.DELETE("/self", middleware.TurnstileCheck(), controller.DeleteSelf) | ||||
| 				selfRoute.DELETE("/self", controller.DeleteSelf) | ||||
| 				selfRoute.GET("/token", controller.GenerateAccessToken) | ||||
| 				selfRoute.GET("/aff", controller.GetAffCode) | ||||
| 				selfRoute.POST("/topup", controller.TopUp) | ||||
|   | ||||
| @@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) { | ||||
| 	modelsRouter := router.Group("/v1/models") | ||||
| 	modelsRouter.Use(middleware.TokenAuth()) | ||||
| 	{ | ||||
| 		modelsRouter.GET("", controller.ListModels) | ||||
| 		modelsRouter.GET("/", controller.ListModels) | ||||
| 		modelsRouter.GET("/:model", controller.RetrieveModel) | ||||
| 	} | ||||
| 	relayV1Router := router.Group("/v1") | ||||
| @@ -25,7 +25,6 @@ func SetRelayRouter(router *gin.Engine) { | ||||
| 		relayV1Router.POST("/images/edits", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/embeddings", controller.Relay) | ||||
| 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay) | ||||
| 		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented) | ||||
| 		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented) | ||||
| 		relayV1Router.GET("/files", controller.RelayNotImplemented) | ||||
|   | ||||
							
								
								
									
										3
									
								
								web/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								web/.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -22,5 +22,4 @@ npm-debug.log* | ||||
| yarn-debug.log* | ||||
| yarn-error.log* | ||||
| .idea | ||||
| package-lock.json | ||||
| yarn.lock | ||||
| yarn.lock | ||||
|   | ||||
							
								
								
									
										18281
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										18281
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -3,18 +3,17 @@ | ||||
|   "version": "0.1.0", | ||||
|   "private": true, | ||||
|   "dependencies": { | ||||
|     "axios": "^0.27.2", | ||||
|     "axios": "^1.4.0", | ||||
|     "history": "^5.3.0", | ||||
|     "marked": "^4.1.1", | ||||
|     "marked": "^5.1.1", | ||||
|     "react": "^18.2.0", | ||||
|     "react-dom": "^18.2.0", | ||||
|     "react-dropzone": "^14.2.3", | ||||
|     "react-router-dom": "^6.3.0", | ||||
|     "react-scripts": "5.0.1", | ||||
|     "react-toastify": "^9.0.8", | ||||
|     "react-turnstile": "^1.0.5", | ||||
|     "react-router-dom": "^6.14.1", | ||||
|     "react-toastify": "^9.1.3", | ||||
|     "react-turnstile": "^1.1.1", | ||||
|     "semantic-ui-css": "^2.5.0", | ||||
|     "semantic-ui-react": "^2.1.3" | ||||
|     "semantic-ui-react": "^2.1.4" | ||||
|   }, | ||||
|   "scripts": { | ||||
|     "start": "react-scripts start", | ||||
| @@ -41,7 +40,8 @@ | ||||
|     ] | ||||
|   }, | ||||
|   "devDependencies": { | ||||
|     "prettier": "^2.7.1" | ||||
|     "prettier": "^2.7.1", | ||||
|     "react-scripts": "^5.0.1" | ||||
|   }, | ||||
|   "prettier": { | ||||
|     "singleQuote": true, | ||||
|   | ||||
| @@ -55,15 +55,6 @@ function App() { | ||||
|       } else { | ||||
|         localStorage.removeItem('chat_link'); | ||||
|       } | ||||
|       if ( | ||||
|         data.version !== process.env.REACT_APP_VERSION && | ||||
|         data.version !== 'v0.0.0' && | ||||
|         process.env.REACT_APP_VERSION !== '' | ||||
|       ) { | ||||
|         showNotice( | ||||
|           `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` | ||||
|         ); | ||||
|       } | ||||
|     } else { | ||||
|       showError('无法正常连接至服务器!'); | ||||
|     } | ||||
|   | ||||
| @@ -363,12 +363,9 @@ const ChannelsTable = () => { | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <Popup | ||||
|                       trigger={<span onClick={() => { | ||||
|                         updateChannelBalance(channel.id, channel.name, idx); | ||||
|                       }} style={{ cursor: 'pointer' }}> | ||||
|                       {renderBalance(channel.type, channel.balance)} | ||||
|                     </span>} | ||||
|                       content="点击更新" | ||||
|                       content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'} | ||||
|                       key={channel.id} | ||||
|                       trigger={renderBalance(channel.type, channel.balance)} | ||||
|                       basic | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
| @@ -383,16 +380,16 @@ const ChannelsTable = () => { | ||||
|                       > | ||||
|                         测试 | ||||
|                       </Button> | ||||
|                       {/*<Button*/} | ||||
|                       {/*  size={'small'}*/} | ||||
|                       {/*  positive*/} | ||||
|                       {/*  loading={updatingBalance}*/} | ||||
|                       {/*  onClick={() => {*/} | ||||
|                       {/*    updateChannelBalance(channel.id, channel.name, idx);*/} | ||||
|                       {/*  }}*/} | ||||
|                       {/*>*/} | ||||
|                       {/*  更新余额*/} | ||||
|                       {/*</Button>*/} | ||||
|                       <Button | ||||
|                         size={'small'} | ||||
|                         positive | ||||
|                         loading={updatingBalance} | ||||
|                         onClick={() => { | ||||
|                           updateChannelBalance(channel.id, channel.name, idx); | ||||
|                         }} | ||||
|                       > | ||||
|                         更新余额 | ||||
|                       </Button> | ||||
|                       <Popup | ||||
|                         trigger={ | ||||
|                           <Button size='small' negative> | ||||
|   | ||||
| @@ -1,31 +1,51 @@ | ||||
| import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||
| import { | ||||
|   Button, | ||||
|   Divider, | ||||
|   Form, | ||||
|   Grid, | ||||
|   Header, | ||||
|   Image, | ||||
|   Message, | ||||
|   Modal, | ||||
|   Segment, | ||||
| } from 'semantic-ui-react'; | ||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { UserContext } from '../context/User'; | ||||
| import { API, getLogo, showError, showSuccess } from '../helpers'; | ||||
| import { API, getLogo, showError, showSuccess, showInfo } from '../helpers'; | ||||
| import Turnstile from 'react-turnstile'; | ||||
|  | ||||
| const LoginForm = () => { | ||||
|   const [inputs, setInputs] = useState({ | ||||
|     username: '', | ||||
|     password: '', | ||||
|     wechat_verification_code: '' | ||||
|     wechat_verification_code: '', | ||||
|   }); | ||||
|   const [searchParams, setSearchParams] = useSearchParams(); | ||||
|   const [submitted, setSubmitted] = useState(false); | ||||
|   const { username, password } = inputs; | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||
|   const [turnstileToken, setTurnstileToken] = useState(''); | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const [status, setStatus] = useState({}); | ||||
|   const logo = getLogo(); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     if (searchParams.get('expired')) { | ||||
|     if (searchParams.get("expired")) { | ||||
|       showError('未登录或登录已过期,请重新登录!'); | ||||
|     } | ||||
|     let status = localStorage.getItem('status'); | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       setStatus(status); | ||||
|  | ||||
|       if (status.turnstile_check) { | ||||
|         setTurnstileEnabled(true); | ||||
|         setTurnstileSiteKey(status.turnstile_site_key); | ||||
|       } | ||||
|     } | ||||
|   }, []); | ||||
|  | ||||
| @@ -65,9 +85,14 @@ const LoginForm = () => { | ||||
|   async function handleSubmit(e) { | ||||
|     setSubmitted(true); | ||||
|     if (username && password) { | ||||
|       const res = await API.post(`/api/user/login`, { | ||||
|       if (turnstileEnabled && turnstileToken === '') { | ||||
|         showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||
|         return; | ||||
|       } | ||||
|  | ||||
|       const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, { | ||||
|         username, | ||||
|         password | ||||
|         password, | ||||
|       }); | ||||
|       const { success, message, data } = res.data; | ||||
|       if (success) { | ||||
| @@ -82,44 +107,54 @@ const LoginForm = () => { | ||||
|   } | ||||
|  | ||||
|   return ( | ||||
|     <Grid textAlign='center' style={{ marginTop: '48px' }}> | ||||
|     <Grid textAlign="center" style={{ marginTop: '48px' }}> | ||||
|       <Grid.Column style={{ maxWidth: 450 }}> | ||||
|         <Header as='h2' color='' textAlign='center'> | ||||
|         <Header as="h2" color="" textAlign="center"> | ||||
|           <Image src={logo} /> 用户登录 | ||||
|         </Header> | ||||
|         <Form size='large'> | ||||
|         <Form size="large"> | ||||
|           <Segment> | ||||
|             <Form.Input | ||||
|               fluid | ||||
|               icon='user' | ||||
|               iconPosition='left' | ||||
|               placeholder='用户名' | ||||
|               name='username' | ||||
|               icon="user" | ||||
|               iconPosition="left" | ||||
|               placeholder="用户名" | ||||
|               name="username" | ||||
|               value={username} | ||||
|               onChange={handleChange} | ||||
|             /> | ||||
|             <Form.Input | ||||
|               fluid | ||||
|               icon='lock' | ||||
|               iconPosition='left' | ||||
|               placeholder='密码' | ||||
|               name='password' | ||||
|               type='password' | ||||
|               icon="lock" | ||||
|               iconPosition="left" | ||||
|               placeholder="密码" | ||||
|               name="password" | ||||
|               type="password" | ||||
|               value={password} | ||||
|               onChange={handleChange} | ||||
|             /> | ||||
|             <Button color='green' fluid size='large' onClick={handleSubmit}> | ||||
|             {turnstileEnabled ? ( | ||||
|               <Turnstile | ||||
|                 sitekey={turnstileSiteKey} | ||||
|                 onVerify={(token) => { | ||||
|                   setTurnstileToken(token); | ||||
|                 }} | ||||
|               /> | ||||
|             ) : ( | ||||
|               <></> | ||||
|             )} | ||||
|             <Button color="" fluid size="large" onClick={handleSubmit}> | ||||
|               登录 | ||||
|             </Button> | ||||
|           </Segment> | ||||
|         </Form> | ||||
|         <Message> | ||||
|           忘记密码? | ||||
|           <Link to='/reset' className='btn btn-link'> | ||||
|           <Link to="/reset" className="btn btn-link"> | ||||
|             点击重置 | ||||
|           </Link> | ||||
|           ; 没有账户? | ||||
|           <Link to='/register' className='btn btn-link'> | ||||
|           <Link to="/register" className="btn btn-link"> | ||||
|             点击注册 | ||||
|           </Link> | ||||
|         </Message> | ||||
| @@ -129,8 +164,8 @@ const LoginForm = () => { | ||||
|             {status.github_oauth ? ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color='black' | ||||
|                 icon='github' | ||||
|                 color="black" | ||||
|                 icon="github" | ||||
|                 onClick={onGitHubOAuthClicked} | ||||
|               /> | ||||
|             ) : ( | ||||
| @@ -139,8 +174,8 @@ const LoginForm = () => { | ||||
|             {status.wechat_login ? ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color='green' | ||||
|                 icon='wechat' | ||||
|                 color="green" | ||||
|                 icon="wechat" | ||||
|                 onClick={onWeChatLoginClicked} | ||||
|               /> | ||||
|             ) : ( | ||||
| @@ -164,18 +199,18 @@ const LoginForm = () => { | ||||
|                   微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) | ||||
|                 </p> | ||||
|               </div> | ||||
|               <Form size='large'> | ||||
|               <Form size="large"> | ||||
|                 <Form.Input | ||||
|                   fluid | ||||
|                   placeholder='验证码' | ||||
|                   name='wechat_verification_code' | ||||
|                   placeholder="验证码" | ||||
|                   name="wechat_verification_code" | ||||
|                   value={inputs.wechat_verification_code} | ||||
|                   onChange={handleChange} | ||||
|                 /> | ||||
|                 <Button | ||||
|                   color='' | ||||
|                   color="" | ||||
|                   fluid | ||||
|                   size='large' | ||||
|                   size="large" | ||||
|                   onClick={onSubmitWeChatVerificationCode} | ||||
|                 > | ||||
|                   登录 | ||||
|   | ||||
| @@ -20,7 +20,6 @@ const OperationSetting = () => { | ||||
|     DisplayInCurrencyEnabled: '', | ||||
|     DisplayTokenStatEnabled: '', | ||||
|     ApproximateTokenEnabled: '', | ||||
|     RetryTimes: 0, | ||||
|   }); | ||||
|   const [originInputs, setOriginInputs] = useState({}); | ||||
|   let [loading, setLoading] = useState(false); | ||||
| @@ -123,9 +122,6 @@ const OperationSetting = () => { | ||||
|         if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { | ||||
|           await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); | ||||
|         } | ||||
|         if (originInputs['RetryTimes'] !== inputs.RetryTimes) { | ||||
|           await updateOption('RetryTimes', inputs.RetryTimes); | ||||
|         } | ||||
|         break; | ||||
|     } | ||||
|   }; | ||||
| @@ -137,7 +133,7 @@ const OperationSetting = () => { | ||||
|           <Header as='h3'> | ||||
|             通用设置 | ||||
|           </Header> | ||||
|           <Form.Group widths={4}> | ||||
|           <Form.Group widths={3}> | ||||
|             <Form.Input | ||||
|               label='充值链接' | ||||
|               name='TopUpLink' | ||||
| @@ -166,17 +162,6 @@ const OperationSetting = () => { | ||||
|               step='0.01' | ||||
|               placeholder='一单位货币能兑换的额度' | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='失败重试次数' | ||||
|               name='RetryTimes' | ||||
|               type={'number'} | ||||
|               step='1' | ||||
|               min='0' | ||||
|               onChange={handleInputChange} | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.RetryTimes} | ||||
|               placeholder='失败重试次数' | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Group inline> | ||||
|             <Form.Checkbox | ||||
|   | ||||
| @@ -12,11 +12,6 @@ const PasswordResetConfirm = () => { | ||||
|  | ||||
|   const [loading, setLoading] = useState(false); | ||||
|  | ||||
|   const [disableButton, setDisableButton] = useState(false); | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|  | ||||
|   const [newPassword, setNewPassword] = useState(''); | ||||
|  | ||||
|   const [searchParams, setSearchParams] = useSearchParams(); | ||||
|   useEffect(() => { | ||||
|     let token = searchParams.get('token'); | ||||
| @@ -27,21 +22,7 @@ const PasswordResetConfirm = () => { | ||||
|     }); | ||||
|   }, []); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let countdownInterval = null; | ||||
|     if (disableButton && countdown > 0) { | ||||
|       countdownInterval = setInterval(() => { | ||||
|         setCountdown(countdown - 1); | ||||
|       }, 1000); | ||||
|     } else if (countdown === 0) { | ||||
|       setDisableButton(false); | ||||
|       setCountdown(30); | ||||
|     } | ||||
|     return () => clearInterval(countdownInterval);  | ||||
|   }, [disableButton, countdown]); | ||||
|  | ||||
|   async function handleSubmit(e) { | ||||
|     setDisableButton(true); | ||||
|     if (!email) return; | ||||
|     setLoading(true); | ||||
|     const res = await API.post(`/api/user/reset`, { | ||||
| @@ -51,15 +32,14 @@ const PasswordResetConfirm = () => { | ||||
|     const { success, message } = res.data; | ||||
|     if (success) { | ||||
|       let password = res.data.data; | ||||
|       setNewPassword(password); | ||||
|       await copy(password); | ||||
|       showNotice(`新密码已复制到剪贴板:${password}`); | ||||
|       showNotice(`密码已重置并已复制到剪贴板:${password}`); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|     setLoading(false); | ||||
|   } | ||||
|    | ||||
|  | ||||
|   return ( | ||||
|     <Grid textAlign='center' style={{ marginTop: '48px' }}> | ||||
|       <Grid.Column style={{ maxWidth: 450 }}> | ||||
| @@ -77,37 +57,20 @@ const PasswordResetConfirm = () => { | ||||
|               value={email} | ||||
|               readOnly | ||||
|             /> | ||||
|             {newPassword && ( | ||||
|               <Form.Input | ||||
|               fluid | ||||
|               icon='lock' | ||||
|               iconPosition='left' | ||||
|               placeholder='新密码' | ||||
|               name='newPassword' | ||||
|               value={newPassword} | ||||
|               readOnly | ||||
|               onClick={(e) => { | ||||
|                 e.target.select(); | ||||
|                 navigator.clipboard.writeText(newPassword); | ||||
|                 showNotice(`密码已复制到剪贴板:${newPassword}`); | ||||
|               }} | ||||
|             />             | ||||
|             )} | ||||
|             <Button | ||||
|               color='green' | ||||
|               color='' | ||||
|               fluid | ||||
|               size='large' | ||||
|               onClick={handleSubmit} | ||||
|               loading={loading} | ||||
|               disabled={disableButton} | ||||
|             > | ||||
|               {disableButton ? `密码重置完成` : '提交'} | ||||
|               提交 | ||||
|             </Button> | ||||
|           </Segment> | ||||
|         </Form> | ||||
|       </Grid.Column> | ||||
|     </Grid> | ||||
|   );   | ||||
|   ); | ||||
| }; | ||||
|  | ||||
| export default PasswordResetConfirm; | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import Turnstile from 'react-turnstile'; | ||||
|  | ||||
| const PasswordResetForm = () => { | ||||
|   const [inputs, setInputs] = useState({ | ||||
|     email: '' | ||||
|     email: '', | ||||
|   }); | ||||
|   const { email } = inputs; | ||||
|  | ||||
| @@ -13,29 +13,24 @@ const PasswordResetForm = () => { | ||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||
|   const [turnstileToken, setTurnstileToken] = useState(''); | ||||
|   const [disableButton, setDisableButton] = useState(false); | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let countdownInterval = null; | ||||
|     if (disableButton && countdown > 0) { | ||||
|       countdownInterval = setInterval(() => { | ||||
|         setCountdown(countdown - 1); | ||||
|       }, 1000); | ||||
|     } else if (countdown === 0) { | ||||
|       setDisableButton(false); | ||||
|       setCountdown(30); | ||||
|     let status = localStorage.getItem('status'); | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       if (status.turnstile_check) { | ||||
|         setTurnstileEnabled(true); | ||||
|         setTurnstileSiteKey(status.turnstile_site_key); | ||||
|       } | ||||
|     } | ||||
|     return () => clearInterval(countdownInterval); | ||||
|   }, [disableButton, countdown]); | ||||
|   }, []); | ||||
|  | ||||
|   function handleChange(e) { | ||||
|     const { name, value } = e.target; | ||||
|     setInputs(inputs => ({ ...inputs, [name]: value })); | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|   } | ||||
|  | ||||
|   async function handleSubmit(e) { | ||||
|     setDisableButton(true); | ||||
|     if (!email) return; | ||||
|     if (turnstileEnabled && turnstileToken === '') { | ||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||
| @@ -83,14 +78,13 @@ const PasswordResetForm = () => { | ||||
|               <></> | ||||
|             )} | ||||
|             <Button | ||||
|               color='green' | ||||
|               color='' | ||||
|               fluid | ||||
|               size='large' | ||||
|               onClick={handleSubmit} | ||||
|               loading={loading} | ||||
|               disabled={disableButton} | ||||
|             > | ||||
|               {disableButton ? `重试 (${countdown})` : '提交'} | ||||
|               提交 | ||||
|             </Button> | ||||
|           </Segment> | ||||
|         </Form> | ||||
|   | ||||
| @@ -1,30 +1,22 @@ | ||||
| import React, { useContext, useEffect, useState } from 'react'; | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; | ||||
| import { Link, useNavigate } from 'react-router-dom'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | ||||
| import Turnstile from 'react-turnstile'; | ||||
| import { UserContext } from '../context/User'; | ||||
|  | ||||
| const PersonalSetting = () => { | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const [inputs, setInputs] = useState({ | ||||
|     wechat_verification_code: '', | ||||
|     email_verification_code: '', | ||||
|     email: '', | ||||
|     self_account_deletion_confirmation: '' | ||||
|   }); | ||||
|   const [status, setStatus] = useState({}); | ||||
|   const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); | ||||
|   const [showEmailBindModal, setShowEmailBindModal] = useState(false); | ||||
|   const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); | ||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||
|   const [turnstileToken, setTurnstileToken] = useState(''); | ||||
|   const [loading, setLoading] = useState(false); | ||||
|   const [disableButton, setDisableButton] = useState(false); | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let status = localStorage.getItem('status'); | ||||
| @@ -38,19 +30,6 @@ const PersonalSetting = () => { | ||||
|     } | ||||
|   }, []); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let countdownInterval = null; | ||||
|     if (disableButton && countdown > 0) { | ||||
|       countdownInterval = setInterval(() => { | ||||
|         setCountdown(countdown - 1); | ||||
|       }, 1000); | ||||
|     } else if (countdown === 0) { | ||||
|       setDisableButton(false); | ||||
|       setCountdown(30); | ||||
|     } | ||||
|     return () => clearInterval(countdownInterval); // Clean up on unmount | ||||
|   }, [disableButton, countdown]); | ||||
|  | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|   }; | ||||
| @@ -78,26 +57,6 @@ const PersonalSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const deleteAccount = async () => { | ||||
|     if (inputs.self_account_deletion_confirmation !== userState.user.username) { | ||||
|       showError('请输入你的账户名以确认删除!'); | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     const res = await API.delete('/api/user/self'); | ||||
|     const { success, message } = res.data; | ||||
|  | ||||
|     if (success) { | ||||
|       showSuccess('账户已删除!'); | ||||
|       await API.get('/api/user/logout'); | ||||
|       userDispatch({ type: 'logout' }); | ||||
|       localStorage.removeItem('user'); | ||||
|       navigate('/login'); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const bindWeChat = async () => { | ||||
|     if (inputs.wechat_verification_code === '') return; | ||||
|     const res = await API.get( | ||||
| @@ -119,7 +78,6 @@ const PersonalSetting = () => { | ||||
|   }; | ||||
|  | ||||
|   const sendVerificationCode = async () => { | ||||
|     setDisableButton(true); | ||||
|     if (inputs.email === '') return; | ||||
|     if (turnstileEnabled && turnstileToken === '') { | ||||
|       showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||
| @@ -165,9 +123,6 @@ const PersonalSetting = () => { | ||||
|       </Button> | ||||
|       <Button onClick={generateAccessToken}>生成系统访问令牌</Button> | ||||
|       <Button onClick={getAffLink}>复制邀请链接</Button> | ||||
|       <Button onClick={() => { | ||||
|         setShowAccountDeleteModal(true); | ||||
|       }}>删除个人账户</Button> | ||||
|       <Divider /> | ||||
|       <Header as='h3'>账号绑定</Header> | ||||
|       { | ||||
| @@ -240,8 +195,8 @@ const PersonalSetting = () => { | ||||
|                 name='email' | ||||
|                 type='email' | ||||
|                 action={ | ||||
|                   <Button onClick={sendVerificationCode} disabled={disableButton || loading}> | ||||
|                     {disableButton ? `重新发送(${countdown})` : '获取验证码'} | ||||
|                   <Button onClick={sendVerificationCode} disabled={loading}> | ||||
|                     获取验证码 | ||||
|                   </Button> | ||||
|                 } | ||||
|               /> | ||||
| @@ -275,47 +230,6 @@ const PersonalSetting = () => { | ||||
|           </Modal.Description> | ||||
|         </Modal.Content> | ||||
|       </Modal> | ||||
|       <Modal | ||||
|         onClose={() => setShowAccountDeleteModal(false)} | ||||
|         onOpen={() => setShowAccountDeleteModal(true)} | ||||
|         open={showAccountDeleteModal} | ||||
|         size={'tiny'} | ||||
|         style={{ maxWidth: '450px' }} | ||||
|       > | ||||
|         <Modal.Header>确认删除自己的帐户</Modal.Header> | ||||
|         <Modal.Content> | ||||
|           <Modal.Description> | ||||
|             <Form size='large'> | ||||
|               <Form.Input | ||||
|                 fluid | ||||
|                 placeholder={`输入你的账户名 ${userState?.user?.username} 以确认删除`} | ||||
|                 name='self_account_deletion_confirmation' | ||||
|                 value={inputs.self_account_deletion_confirmation} | ||||
|                 onChange={handleInputChange} | ||||
|               /> | ||||
|               {turnstileEnabled ? ( | ||||
|                 <Turnstile | ||||
|                   sitekey={turnstileSiteKey} | ||||
|                   onVerify={(token) => { | ||||
|                     setTurnstileToken(token); | ||||
|                   }} | ||||
|                 /> | ||||
|               ) : ( | ||||
|                 <></> | ||||
|               )} | ||||
|               <Button | ||||
|                 color='red' | ||||
|                 fluid | ||||
|                 size='large' | ||||
|                 onClick={deleteAccount} | ||||
|                 loading={loading} | ||||
|               > | ||||
|                 删除 | ||||
|               </Button> | ||||
|             </Form> | ||||
|           </Modal.Description> | ||||
|         </Modal.Content> | ||||
|       </Modal> | ||||
|     </div> | ||||
|   ); | ||||
| }; | ||||
|   | ||||
| @@ -1,5 +1,13 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; | ||||
| import { | ||||
|   Button, | ||||
|   Form, | ||||
|   Grid, | ||||
|   Header, | ||||
|   Image, | ||||
|   Message, | ||||
|   Segment, | ||||
| } from 'semantic-ui-react'; | ||||
| import { Link, useNavigate } from 'react-router-dom'; | ||||
| import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; | ||||
| import Turnstile from 'react-turnstile'; | ||||
| @@ -10,7 +18,7 @@ const RegisterForm = () => { | ||||
|     password: '', | ||||
|     password2: '', | ||||
|     email: '', | ||||
|     verification_code: '' | ||||
|     verification_code: '', | ||||
|   }); | ||||
|   const { username, password, password2 } = inputs; | ||||
|   const [showEmailVerification, setShowEmailVerification] = useState(false); | ||||
| @@ -170,7 +178,7 @@ const RegisterForm = () => { | ||||
|               <></> | ||||
|             )} | ||||
|             <Button | ||||
|               color='green' | ||||
|               color='' | ||||
|               fluid | ||||
|               size='large' | ||||
|               onClick={handleSubmit} | ||||
|   | ||||
| @@ -226,8 +226,7 @@ const UsersTable = () => { | ||||
|                     <Popup | ||||
|                       content={user.email ? user.email : '未绑定邮箱地址'} | ||||
|                       key={user.username} | ||||
|                       header={user.display_name ? user.display_name : user.username} | ||||
|                       trigger={<span>{renderText(user.username, 15)}</span>} | ||||
|                       trigger={<span>{renderText(user.username, 10)}</span>} | ||||
|                       hoverable | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
|   | ||||
| @@ -1,20 +1,14 @@ | ||||
| export const CHANNEL_OPTIONS = [ | ||||
|   { key: 1, text: 'OpenAI', value: 1, color: 'green' }, | ||||
|   { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, | ||||
|   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||
|   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||
|   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||
|   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||
|   { key: 18, text: '讯飞星火认知大模型', value: 18, color: 'blue' }, | ||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, | ||||
|   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, | ||||
|   { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, | ||||
|   { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, | ||||
|   { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, | ||||
|   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, | ||||
|   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, | ||||
|   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, | ||||
|   { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } | ||||
|   { key: 8, text: '自定义', value: 8, color: 'pink' }, | ||||
|   { key: 3, text: 'Azure', value: 3, color: 'olive' }, | ||||
|   { key: 2, text: 'API2D', value: 2, color: 'blue' }, | ||||
|   { key: 4, text: 'CloseAI', value: 4, color: 'teal' }, | ||||
|   { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' }, | ||||
|   { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' }, | ||||
|   { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' }, | ||||
|   { key: 9, text: 'AI.LS', value: 9, color: 'yellow' }, | ||||
|   { key: 10, text: 'AI Proxy', value: 10, color: 'purple' }, | ||||
|   { key: 12, text: 'API2GPT', value: 12, color: 'blue' }, | ||||
|   { key: 13, text: 'AIGC2D', value: 13, color: 'purple' } | ||||
| ]; | ||||
| @@ -1,5 +1,5 @@ | ||||
| export const toastConstants = { | ||||
|   SUCCESS_TIMEOUT: 1500, | ||||
|   SUCCESS_TIMEOUT: 500, | ||||
|   INFO_TIMEOUT: 3000, | ||||
|   ERROR_TIMEOUT: 5000, | ||||
|   WARNING_TIMEOUT: 10000, | ||||
|   | ||||
| @@ -46,7 +46,9 @@ const About = () => { | ||||
|             about.startsWith('https://') ? <iframe | ||||
|               src={about} | ||||
|               style={{ width: '100%', height: '100vh', border: 'none' }} | ||||
|             /> : <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> | ||||
|             /> : <Segment> | ||||
|               <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div> | ||||
|             </Segment> | ||||
|           } | ||||
|         </> | ||||
|       } | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; | ||||
| import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | ||||
| import { useParams } from 'react-router-dom'; | ||||
| import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; | ||||
| import { CHANNEL_OPTIONS } from '../../constants'; | ||||
| @@ -27,38 +27,12 @@ const EditChannel = () => { | ||||
|   }; | ||||
|   const [batch, setBatch] = useState(false); | ||||
|   const [inputs, setInputs] = useState(originInputs); | ||||
|   const [originModelOptions, setOriginModelOptions] = useState([]); | ||||
|   const [modelOptions, setModelOptions] = useState([]); | ||||
|   const [groupOptions, setGroupOptions] = useState([]); | ||||
|   const [basicModels, setBasicModels] = useState([]); | ||||
|   const [fullModels, setFullModels] = useState([]); | ||||
|   const [customModel, setCustomModel] = useState(''); | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|     if (name === 'type' && inputs.models.length === 0) { | ||||
|       let localModels = []; | ||||
|       switch (value) { | ||||
|         case 14: | ||||
|           localModels = ['claude-instant-1', 'claude-2']; | ||||
|           break; | ||||
|         case 11: | ||||
|           localModels = ['PaLM-2']; | ||||
|           break; | ||||
|         case 15: | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||
|           break; | ||||
|         case 17: | ||||
|           localModels = ['qwen-v1', 'qwen-plus-v1']; | ||||
|           break; | ||||
|         case 16: | ||||
|           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           break; | ||||
|         case 18: | ||||
|           localModels = ['SparkDesk']; | ||||
|           break; | ||||
|       } | ||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const loadChannel = async () => { | ||||
| @@ -88,16 +62,13 @@ const EditChannel = () => { | ||||
|   const fetchModels = async () => { | ||||
|     try { | ||||
|       let res = await API.get(`/api/channel/models`); | ||||
|       let localModelOptions = res.data.data.map((model) => ({ | ||||
|       setModelOptions(res.data.data.map((model) => ({ | ||||
|         key: model.id, | ||||
|         text: model.id, | ||||
|         value: model.id | ||||
|       })); | ||||
|       setOriginModelOptions(localModelOptions); | ||||
|       }))); | ||||
|       setFullModels(res.data.data.map((model) => model.id)); | ||||
|       setBasicModels(res.data.data.filter((model) => { | ||||
|         return model.id.startsWith('gpt-3') || model.id.startsWith('text-'); | ||||
|       }).map((model) => model.id)); | ||||
|       setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id)); | ||||
|     } catch (error) { | ||||
|       showError(error.message); | ||||
|     } | ||||
| @@ -116,20 +87,6 @@ const EditChannel = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let localModelOptions = [...originModelOptions]; | ||||
|     inputs.models.forEach((model) => { | ||||
|       if (!localModelOptions.find((option) => option.key === model)) { | ||||
|         localModelOptions.push({ | ||||
|           key: model, | ||||
|           text: model, | ||||
|           value: model | ||||
|         }); | ||||
|       } | ||||
|     }); | ||||
|     setModelOptions(localModelOptions); | ||||
|   }, [originModelOptions, inputs.models]); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     if (isEdit) { | ||||
|       loadChannel().then(); | ||||
| @@ -156,10 +113,7 @@ const EditChannel = () => { | ||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||
|     } | ||||
|     if (localInputs.type === 3 && localInputs.other === '') { | ||||
|       localInputs.other = '2023-06-01-preview'; | ||||
|     } | ||||
|     if (localInputs.model_mapping === '') { | ||||
|       localInputs.model_mapping = '{}'; | ||||
|       localInputs.other = '2023-03-15-preview'; | ||||
|     } | ||||
|     let res; | ||||
|     localInputs.models = localInputs.models.join(','); | ||||
| @@ -219,7 +173,7 @@ const EditChannel = () => { | ||||
|                   <Form.Input | ||||
|                     label='默认 API 版本' | ||||
|                     name='other' | ||||
|                     placeholder={'请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||
|                     placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||
|                     onChange={handleInputChange} | ||||
|                     value={inputs.other} | ||||
|                     autoComplete='new-password' | ||||
| @@ -242,12 +196,26 @@ const EditChannel = () => { | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           { | ||||
|             inputs.type !== 3 && inputs.type !== 8 && ( | ||||
|               <Form.Field> | ||||
|                 <Form.Input | ||||
|                   label='镜像' | ||||
|                   name='base_url' | ||||
|                   placeholder={'此项可选,输入镜像站地址,格式为:https://domain.com'} | ||||
|                   onChange={handleInputChange} | ||||
|                   value={inputs.base_url} | ||||
|                   autoComplete='new-password' | ||||
|                 /> | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='名称' | ||||
|               required | ||||
|               name='name' | ||||
|               placeholder={'请为渠道命名'} | ||||
|               placeholder={'请输入名称'} | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.name} | ||||
|               autoComplete='new-password' | ||||
| @@ -256,7 +224,7 @@ const EditChannel = () => { | ||||
|           <Form.Field> | ||||
|             <Form.Dropdown | ||||
|               label='分组' | ||||
|               placeholder={'请选择可以使用该渠道的分组'} | ||||
|               placeholder={'请选择分组'} | ||||
|               name='groups' | ||||
|               required | ||||
|               fluid | ||||
| @@ -273,7 +241,7 @@ const EditChannel = () => { | ||||
|           <Form.Field> | ||||
|             <Form.Dropdown | ||||
|               label='模型' | ||||
|               placeholder={'请选择该渠道所支持的模型'} | ||||
|               placeholder={'请选择该通道所支持的模型'} | ||||
|               name='models' | ||||
|               required | ||||
|               fluid | ||||
| @@ -295,37 +263,11 @@ const EditChannel = () => { | ||||
|             <Button type={'button'} onClick={() => { | ||||
|               handleInputChange(null, { name: 'models', value: [] }); | ||||
|             }}>清除所有模型</Button> | ||||
|             <Input | ||||
|               action={ | ||||
|                 <Button type={'button'} onClick={() => { | ||||
|                   if (customModel.trim() === '') return; | ||||
|                   if (inputs.models.includes(customModel)) return; | ||||
|                   let localModels = [...inputs.models]; | ||||
|                   localModels.push(customModel); | ||||
|                   let localModelOptions = []; | ||||
|                   localModelOptions.push({ | ||||
|                     key: customModel, | ||||
|                     text: customModel, | ||||
|                     value: customModel | ||||
|                   }); | ||||
|                   setModelOptions(modelOptions => { | ||||
|                     return [...modelOptions, ...localModelOptions]; | ||||
|                   }); | ||||
|                   setCustomModel(''); | ||||
|                   handleInputChange(null, { name: 'models', value: localModels }); | ||||
|                 }}>填入</Button> | ||||
|               } | ||||
|               placeholder='输入自定义模型名称' | ||||
|               value={customModel} | ||||
|               onChange={(e, { value }) => { | ||||
|                 setCustomModel(value); | ||||
|               }} | ||||
|             /> | ||||
|           </div> | ||||
|           <Form.Field> | ||||
|             <Form.TextArea | ||||
|               label='模型重定向' | ||||
|               placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} | ||||
|               label='模型映射' | ||||
|               placeholder={`此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} | ||||
|               name='model_mapping' | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.model_mapping} | ||||
| @@ -350,7 +292,7 @@ const EditChannel = () => { | ||||
|                 label='密钥' | ||||
|                 name='key' | ||||
|                 required | ||||
|                 placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} | ||||
|                 placeholder={'请输入密钥'} | ||||
|                 onChange={handleInputChange} | ||||
|                 value={inputs.key} | ||||
|                 autoComplete='new-password' | ||||
| @@ -367,21 +309,7 @@ const EditChannel = () => { | ||||
|               /> | ||||
|             ) | ||||
|           } | ||||
|           { | ||||
|             inputs.type !== 3 && inputs.type !== 8 && ( | ||||
|               <Form.Field> | ||||
|                 <Form.Input | ||||
|                   label='镜像' | ||||
|                   name='base_url' | ||||
|                   placeholder={'此项可选,用于通过镜像站来进行 API 调用,请输入镜像站地址,格式为:https://domain.com'} | ||||
|                   onChange={handleInputChange} | ||||
|                   value={inputs.base_url} | ||||
|                   autoComplete='new-password' | ||||
|                 /> | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button> | ||||
|           <Button positive onClick={submit}>提交</Button> | ||||
|         </Form> | ||||
|       </Segment> | ||||
|     </> | ||||
|   | ||||
| @@ -83,7 +83,7 @@ const EditToken = () => { | ||||
|       if (isEdit) { | ||||
|         showSuccess('令牌更新成功!'); | ||||
|       } else { | ||||
|         showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!'); | ||||
|         showSuccess('令牌创建成功!'); | ||||
|         setInputs(originInputs); | ||||
|       } | ||||
|     } else { | ||||
|   | ||||
| @@ -7,32 +7,24 @@ const TopUp = () => { | ||||
|   const [redemptionCode, setRedemptionCode] = useState(''); | ||||
|   const [topUpLink, setTopUpLink] = useState(''); | ||||
|   const [userQuota, setUserQuota] = useState(0); | ||||
|   const [isSubmitting, setIsSubmitting] = useState(false); | ||||
|  | ||||
|   const topUp = async () => { | ||||
|     if (redemptionCode === '') { | ||||
|       showInfo('请输入充值码!') | ||||
|       return; | ||||
|     } | ||||
|     setIsSubmitting(true); | ||||
|     try { | ||||
|       const res = await API.post('/api/user/topup', { | ||||
|         key: redemptionCode | ||||
|     const res = await API.post('/api/user/topup', { | ||||
|       key: redemptionCode | ||||
|     }); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       showSuccess('充值成功!'); | ||||
|       setUserQuota((quota) => { | ||||
|         return quota + data; | ||||
|       }); | ||||
|       const { success, message, data } = res.data; | ||||
|       if (success) { | ||||
|         showSuccess('充值成功!'); | ||||
|         setUserQuota((quota) => { | ||||
|           return quota + data; | ||||
|         }); | ||||
|         setRedemptionCode(''); | ||||
|       } else { | ||||
|         showError(message); | ||||
|       } | ||||
|     } catch (err) { | ||||
|       showError('请求失败'); | ||||
|     } finally { | ||||
|       setIsSubmitting(false);  | ||||
|       setRedemptionCode(''); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
| @@ -82,8 +74,8 @@ const TopUp = () => { | ||||
|             <Button color='green' onClick={openTopUpLink}> | ||||
|               获取兑换码 | ||||
|             </Button> | ||||
|             <Button color='yellow' onClick={topUp} disabled={isSubmitting}> | ||||
|                 {isSubmitting ? '兑换中...' : '兑换'} | ||||
|             <Button color='yellow' onClick={topUp}> | ||||
|               充值 | ||||
|             </Button> | ||||
|           </Form> | ||||
|         </Grid.Column> | ||||
| @@ -100,4 +92,5 @@ const TopUp = () => { | ||||
|   ); | ||||
| }; | ||||
|  | ||||
| export default TopUp; | ||||
|  | ||||
| export default TopUp; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user