mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-26 11:23:43 +08:00 
			
		
		
		
	Compare commits
	
		
			29 Commits
		
	
	
		
			0.0.2
			...
			v0.5.2-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | b464e2907a | ||
|  | d96cf2e84d | ||
|  | 446337c329 | ||
|  | 1dfa190e79 | ||
|  | 2d49ca6a07 | ||
|  | 89bcaaf989 | ||
|  | afcd1bd27b | ||
|  | c2c455c980 | ||
|  | 30a7f1a1c7 | ||
|  | c9d2e42a9e | ||
|  | 3fca6ff534 | ||
|  | 8cbbeb784f | ||
|  | ec88c0c240 | ||
|  | 065147b440 | ||
|  | fe8f216dd9 | ||
|  | b7d0616ae0 | ||
|  | ce9c8024a6 | ||
|  | 8a866078b2 | ||
|  | 3e81d8af45 | ||
|  | b8cb86c2c1 | ||
|  | f45d586400 | ||
|  | 50dec03ff3 | ||
|  | f31d400b6f | ||
|  | 130e6bfd83 | ||
|  | d1335ebc01 | ||
|  | e92da7928b | ||
|  | d1b6f492b6 | ||
|  | b9f6461dd4 | ||
|  | 0a39521a3d | 
							
								
								
									
										2
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -38,7 +38,7 @@ jobs: | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             ckt1031/one-api-en | ||||
|             justsong/one-api-en | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|         uses: docker/build-push-action@v3 | ||||
|   | ||||
							
								
								
									
										2
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -42,7 +42,7 @@ jobs: | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             ckt1031/one-api | ||||
|             justsong/one-api | ||||
|             ghcr.io/${{ github.repository }} | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|   | ||||
							
								
								
									
										2
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -49,7 +49,7 @@ jobs: | ||||
|         uses: docker/metadata-action@v4 | ||||
|         with: | ||||
|           images: | | ||||
|             ckt1031/one-api | ||||
|             justsong/one-api | ||||
|             ghcr.io/${{ github.repository }} | ||||
|  | ||||
|       - name: Build and push Docker images | ||||
|   | ||||
							
								
								
									
										25
									
								
								README.en.md
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								README.en.md
									
									
									
									
									
								
							| @@ -57,15 +57,13 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use | ||||
| > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. | ||||
|  | ||||
| ## Features | ||||
| 1. Supports multiple API access channels: | ||||
|     + [x] Official OpenAI channel (support proxy configuration) | ||||
|     + [x] **Azure OpenAI API** | ||||
|     + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) | ||||
|     + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|     + [x] [API2D](https://api2d.com/r/197971) | ||||
|     + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
|     + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`) | ||||
|     + [x] Custom channel: Various third-party proxy services not included in the list | ||||
| 1. Support for multiple large models: | ||||
|    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||
|    + [x] [Anthropic Claude Series Models](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) | ||||
|    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) | ||||
| 2. Supports access to multiple channels through **load balancing**. | ||||
| 3. Supports **stream mode** that enables typewriter-like effect through stream transmission. | ||||
| 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. | ||||
| @@ -139,7 +137,7 @@ The initial account username is `root` and password is `123456`. | ||||
|    cd one-api/web | ||||
|    npm install | ||||
|    npm run build | ||||
|  | ||||
|     | ||||
|    # Build the backend | ||||
|    cd .. | ||||
|    go mod download | ||||
| @@ -175,7 +173,12 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co | ||||
| <summary><strong>Deploy on Sealos</strong></summary> | ||||
| <div> | ||||
|  | ||||
| Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md). | ||||
| > Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users. | ||||
|  | ||||
| > Click the button below to deploy with one click.👇 | ||||
|  | ||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||
|  | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
|   | ||||
							
								
								
									
										14
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								README.md
									
									
									
									
									
								
							| @@ -63,9 +63,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [Anthropic Claude 系列模型](https://anthropic.com) | ||||
|    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) | ||||
|    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) | ||||
|    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) | ||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||
| 2. 支持配置镜像以及众多第三方代理服务: | ||||
|    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) | ||||
|    + [x] [OpenAI-SB](https://openai-sb.com) | ||||
|    + [x] [API2D](https://api2d.com/r/197971) | ||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||
| @@ -93,7 +94,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 19. 支持通过系统访问令牌访问管理 API。 | ||||
| 20. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 21. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册以及通过邮箱进行密码重置。 | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
|  | ||||
| @@ -152,7 +153,7 @@ sudo service nginx restart | ||||
|    cd one-api/web | ||||
|    npm install | ||||
|    npm run build | ||||
|  | ||||
|     | ||||
|    # 构建后端 | ||||
|    cd .. | ||||
|    go mod download | ||||
| @@ -210,9 +211,11 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
| <summary><strong>部署到 Sealos </strong></summary> | ||||
| <div> | ||||
|  | ||||
| > Sealos 可视化部署,仅需 1 分钟。 | ||||
| > Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩。 | ||||
|  | ||||
| 参考这个[教程](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md)中 1~5 步。 | ||||
| 点击以下按钮一键部署: | ||||
|  | ||||
| [](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) | ||||
|  | ||||
| </div> | ||||
| </details> | ||||
| @@ -313,6 +316,7 @@ https://openai.justsong.cn | ||||
|    + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) | ||||
|    + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 | ||||
|    + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 | ||||
|    + 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 | ||||
| 2. 账户额度足够为什么提示额度不足? | ||||
|    + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 | ||||
|    + 令牌额度仅供用户设置最大使用量,用户可自由设置。 | ||||
|   | ||||
| @@ -38,12 +38,23 @@ var PasswordLoginEnabled = true | ||||
| var PasswordRegisterEnabled = true | ||||
| var EmailVerificationEnabled = false | ||||
| var GitHubOAuthEnabled = false | ||||
| var DiscordOAuthEnabled = false | ||||
| var WeChatAuthEnabled = false | ||||
| var GoogleOAuthEnabled = false | ||||
| var TurnstileCheckEnabled = false | ||||
| var RegisterEnabled = true | ||||
|  | ||||
| var EmailDomainRestrictionEnabled = false | ||||
| var EmailDomainWhitelist = []string{ | ||||
| 	"gmail.com", | ||||
| 	"163.com", | ||||
| 	"126.com", | ||||
| 	"qq.com", | ||||
| 	"outlook.com", | ||||
| 	"hotmail.com", | ||||
| 	"icloud.com", | ||||
| 	"yahoo.com", | ||||
| 	"foxmail.com", | ||||
| } | ||||
|  | ||||
| var LogConsumeEnabled = true | ||||
|  | ||||
| var SMTPServer = "" | ||||
| @@ -55,16 +66,10 @@ var SMTPToken = "" | ||||
| var GitHubClientId = "" | ||||
| var GitHubClientSecret = "" | ||||
|  | ||||
| var DiscordClientId = "" | ||||
| var DiscordClientSecret = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
|  | ||||
| var GoogleClientId = "" | ||||
| var GoogleClientSecret = "" | ||||
|  | ||||
| var TurnstileSiteKey = "" | ||||
| var TurnstileSecretKey = "" | ||||
|  | ||||
| @@ -146,16 +151,6 @@ const ( | ||||
| 	ChannelStatusDisabled = 2 // also don't use 0 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelAllowNonStreamEnabled  = 1 | ||||
| 	ChannelAllowNonStreamDisabled = 2 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelAllowStreamEnabled  = 1 | ||||
| 	ChannelAllowStreamDisabled = 2 | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ChannelTypeUnknown   = 0 | ||||
| 	ChannelTypeOpenAI    = 1 | ||||
| @@ -174,24 +169,28 @@ const ( | ||||
| 	ChannelTypeAnthropic = 14 | ||||
| 	ChannelTypeBaidu     = 15 | ||||
| 	ChannelTypeZhipu     = 16 | ||||
| 	ChannelTypeAli       = 17 | ||||
| 	ChannelTypeXunfei    = 18 | ||||
| ) | ||||
|  | ||||
| var ChannelBaseURLs = []string{ | ||||
| 	"",                              // 0 | ||||
| 	"https://api.openai.com",        // 1 | ||||
| 	"https://oa.api2d.net",          // 2 | ||||
| 	"",                              // 3 | ||||
| 	"https://api.closeai-proxy.xyz", // 4 | ||||
| 	"https://api.openai-sb.com",     // 5 | ||||
| 	"https://api.openaimax.com",     // 6 | ||||
| 	"https://api.ohmygpt.com",       // 7 | ||||
| 	"",                              // 8 | ||||
| 	"https://api.caipacity.com",     // 9 | ||||
| 	"https://api.aiproxy.io",        // 10 | ||||
| 	"",                              // 11 | ||||
| 	"https://api.api2gpt.com",       // 12 | ||||
| 	"https://api.aigc2d.com",        // 13 | ||||
| 	"https://api.anthropic.com",     // 14 | ||||
| 	"https://aip.baidubce.com",      // 15 | ||||
| 	"https://open.bigmodel.cn",      // 16 | ||||
| 	"",                               // 0 | ||||
| 	"https://api.openai.com",         // 1 | ||||
| 	"https://oa.api2d.net",           // 2 | ||||
| 	"",                               // 3 | ||||
| 	"https://api.closeai-proxy.xyz",  // 4 | ||||
| 	"https://api.openai-sb.com",      // 5 | ||||
| 	"https://api.openaimax.com",      // 6 | ||||
| 	"https://api.ohmygpt.com",        // 7 | ||||
| 	"",                               // 8 | ||||
| 	"https://api.caipacity.com",      // 9 | ||||
| 	"https://api.aiproxy.io",         // 10 | ||||
| 	"",                               // 11 | ||||
| 	"https://api.api2gpt.com",        // 12 | ||||
| 	"https://api.aigc2d.com",         // 13 | ||||
| 	"https://api.anthropic.com",      // 14 | ||||
| 	"https://aip.baidubce.com",       // 15 | ||||
| 	"https://open.bigmodel.cn",       // 16 | ||||
| 	"https://dashscope.aliyuncs.com", // 17 | ||||
| 	"",                               // 18 | ||||
| } | ||||
|   | ||||
| @@ -42,10 +42,14 @@ var ModelRatio = map[string]float64{ | ||||
| 	"claude-2":                30, | ||||
| 	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens | ||||
| 	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens | ||||
| 	"PaLM-2":                  1, | ||||
| 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens | ||||
| 	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag | ||||
| 	"qwen-plus-v1":            0.5715, // Same as above | ||||
| 	"SparkDesk":               0.8572, // TBD | ||||
| } | ||||
|  | ||||
| func ModelRatio2JSONString() string { | ||||
|   | ||||
| @@ -1,20 +1,17 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { | ||||
| @@ -26,6 +23,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr | ||||
| 	case common.ChannelTypeBaidu: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		fallthrough | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		request.Model = "gpt-35-turbo" | ||||
| @@ -61,83 +60,21 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
|  | ||||
| 	if channel.AllowStreaming == common.ChannelAllowStreamEnabled && isStream { | ||||
| 		responseText := "" | ||||
| 		scanner := bufio.NewScanner(resp.Body) | ||||
| 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 			if atEOF && len(data) == 0 { | ||||
| 				return 0, nil, nil | ||||
| 			} | ||||
| 			if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 				return i + 1, data[0:i], nil | ||||
| 			} | ||||
| 			if atEOF { | ||||
| 				return len(data), data, nil | ||||
| 			} | ||||
| 			return 0, nil, nil | ||||
| 		}) | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			// ChatGPT Next Web | ||||
| 			if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { | ||||
| 				// Remove event: event in the front or back | ||||
| 				data = strings.TrimPrefix(data, "event: event") | ||||
| 				data = strings.TrimSuffix(data, "event: event") | ||||
| 				// Remove everything, only keep `data: {...}` <--- this is the json | ||||
| 				// Find the start and end indices of `data: {...}` substring | ||||
| 				startIndex := strings.Index(data, "data:") | ||||
| 				endIndex := strings.LastIndex(data, "}") | ||||
|  | ||||
| 				// If both indices are found and end index is greater than start index | ||||
| 				if startIndex != -1 && endIndex != -1 && endIndex > startIndex { | ||||
| 					// Extract the `data: {...}` substring | ||||
| 					data = data[startIndex : endIndex+1] | ||||
| 				} | ||||
| 			} | ||||
| 			if !strings.HasPrefix(data, "data:") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[6:] | ||||
| 			if !strings.HasPrefix(data, "[DONE]") { | ||||
| 				var streamResponse ChatCompletionsStreamResponse | ||||
| 				err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 				if err != nil { | ||||
| 					return err, nil | ||||
| 				} | ||||
| 				for _, choice := range streamResponse.Choices { | ||||
| 					responseText += choice.Delta.Content | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if responseText == "" { | ||||
| 			return errors.New("Empty response"), nil | ||||
| 		} | ||||
| 	} else { | ||||
| 		var response TextResponse | ||||
| 		err = json.NewDecoder(resp.Body).Decode(&response) | ||||
| 		if err != nil { | ||||
| 			return err, nil | ||||
| 		} | ||||
| 		if response.Usage.CompletionTokens == 0 { | ||||
| 			return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||
| 		} | ||||
| 	var response TextResponse | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&response) | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	if response.Usage.CompletionTokens == 0 { | ||||
| 		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error | ||||
| 	} | ||||
|  | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func buildTestRequest(stream bool) *ChatRequest { | ||||
| func buildTestRequest() *ChatRequest { | ||||
| 	testRequest := &ChatRequest{ | ||||
| 		Model:     "", // this will be set later | ||||
| 		MaxTokens: 1, | ||||
| 		Stream:    stream, | ||||
| 	} | ||||
| 	testMessage := Message{ | ||||
| 		Role:    "user", | ||||
| @@ -164,7 +101,7 @@ func TestChannel(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled) | ||||
| 	testRequest := buildTestRequest() | ||||
| 	tik := time.Now() | ||||
| 	err, _ = testChannel(channel, *testRequest) | ||||
| 	tok := time.Now() | ||||
| @@ -219,6 +156,7 @@ func testAllChannels(notify bool) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	testRequest := buildTestRequest() | ||||
| 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000) | ||||
| 	if disableThreshold == 0 { | ||||
| 		disableThreshold = 10000000 // a impossible value | ||||
| @@ -229,7 +167,6 @@ func testAllChannels(notify bool) error { | ||||
| 				continue | ||||
| 			} | ||||
| 			tik := time.Now() | ||||
| 			testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled) | ||||
| 			err, openaiErr := testChannel(channel, *testRequest) | ||||
| 			tok := time.Now() | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
|   | ||||
| @@ -1,13 +1,12 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func GetAllChannels(c *gin.Context) { | ||||
|   | ||||
| @@ -1,223 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type DiscordOAuthResponse struct { | ||||
| 	AccessToken string `json:"access_token"` | ||||
| 	Scope       string `json:"scope"` | ||||
| 	TokenType   string `json:"token_type"` | ||||
| } | ||||
|  | ||||
| type DiscordUser struct { | ||||
| 	Id       string `json:"id"` | ||||
| 	Username string `json:"username"` | ||||
| } | ||||
|  | ||||
| func getDiscordUserInfoByCode(codeFromURLParamaters string, host string) (*DiscordUser, error) { | ||||
| 	if codeFromURLParamaters == "" { | ||||
| 		return nil, errors.New("无效参数") | ||||
| 	} | ||||
|  | ||||
| 	RequestClient := &http.Client{} | ||||
|  | ||||
| 	accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf( | ||||
| 		"client_id=%s&client_secret=%s&grant_type=authorization_code&redirect_uri=https://%s/oauth/discord&code=%s&scope=identify", | ||||
| 		common.DiscordClientId, common.DiscordClientSecret, host, codeFromURLParamaters, | ||||
| 	))) | ||||
|  | ||||
| 	req, _ := http.NewRequest("POST", | ||||
| 		"https://discordapp.com/api/oauth2/token", | ||||
| 		accessTokenBody, | ||||
| 	) | ||||
|  | ||||
| 	req.Header = http.Header{ | ||||
| 		"Content-Type": []string{"application/x-www-form-urlencoded"}, | ||||
| 		"Accept":       []string{"application/json"}, | ||||
| 	} | ||||
|  | ||||
| 	resp, err := RequestClient.Do(req) | ||||
|  | ||||
| 	if resp.StatusCode != 200 || err != nil { | ||||
| 		return nil, errors.New("访问令牌无效") | ||||
| 	} | ||||
|  | ||||
| 	var discordOAuthResponse DiscordOAuthResponse | ||||
|  | ||||
| 	json.NewDecoder(resp.Body).Decode(&discordOAuthResponse) | ||||
|  | ||||
| 	accessToken := fmt.Sprintf("Bearer %s", discordOAuthResponse.AccessToken) | ||||
|  | ||||
| 	// Get User Info | ||||
| 	req, _ = http.NewRequest("GET", "https://discord.com/api/users/@me", nil) | ||||
|  | ||||
| 	req.Header = http.Header{ | ||||
| 		"Content-Type":  []string{"application/json"}, | ||||
| 		"Authorization": []string{accessToken}, | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	resp, err = RequestClient.Do(req) | ||||
|  | ||||
| 	if resp.StatusCode != 200 || err != nil { | ||||
| 		return nil, errors.New("Discord 用户信息无效") | ||||
| 	} | ||||
|  | ||||
| 	var discordUser DiscordUser | ||||
|  | ||||
| 	json.NewDecoder(resp.Body).Decode(&discordUser) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if discordUser.Id == "" { | ||||
| 		return nil, errors.New("返回值无效,用户字段为空,请稍后再试!") | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	return &discordUser, nil | ||||
| } | ||||
|  | ||||
| func DiscordOAuth(c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	username := session.Get("username") | ||||
| 	if username != nil { | ||||
| 		DiscordBind(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !common.DiscordOAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 Discord 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
|  | ||||
| 	discordUser, err := getDiscordUserInfoByCode(code, c.Request.Host) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		DiscordId: discordUser.Id, | ||||
| 	} | ||||
| 	if model.IsDiscordIdAlreadyTaken(user.DiscordId) { | ||||
| 		err := user.FillUserByDiscordId() | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": err.Error(), | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if common.RegisterEnabled { | ||||
| 			user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			if discordUser.Username != "" { | ||||
| 				user.DisplayName = discordUser.Username | ||||
| 			} else { | ||||
| 				user.DisplayName = "Discord User" | ||||
| 			} | ||||
| 			user.Role = common.RoleCommonUser | ||||
| 			user.Status = common.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员关闭了新用户注册", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if user.Status != common.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| func DiscordBind(c *gin.Context) { | ||||
| 	if !common.DiscordOAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 Discord 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
|  | ||||
| 	discordUser, err := getDiscordUserInfoByCode(code, c.Request.Host) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		DiscordId: discordUser.Id, | ||||
| 	} | ||||
| 	if model.IsDiscordIdAlreadyTaken(user.DiscordId) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "该 Discord 账户已被绑定", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	session := sessions.Default(c) | ||||
| 	id := session.Get("id") | ||||
| 	// id := c.GetInt("id")  // critical bug! | ||||
| 	user.Id = id.(int) | ||||
| 	err = user.FillUserById() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.DiscordId = discordUser.Id | ||||
| 	err = user.Update(false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "bind", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -1,226 +0,0 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type GoogleAccessTokenResponse struct { | ||||
| 	AccessToken  string `json:"access_token"` | ||||
| 	ExpiresIn    int    `json:"expires_in"` | ||||
| 	TokenType    string `json:"token_type"` | ||||
| 	Scope        string `json:"scope"` | ||||
| 	RefreshToken string `json:"refresh_token"` | ||||
| } | ||||
|  | ||||
| type GoogleUser struct { | ||||
| 	Sub  string `json:"sub"` | ||||
| 	Name string `json:"name"` | ||||
| } | ||||
|  | ||||
| func getGoogleUserInfoByCode(codeFromURLParamaters string, host string) (*GoogleUser, error) { | ||||
| 	if codeFromURLParamaters == "" { | ||||
| 		return nil, errors.New("无效参数") | ||||
| 	} | ||||
|  | ||||
| 	RequestClient := &http.Client{} | ||||
|  | ||||
| 	accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf( | ||||
| 		"code=%s&client_id=%s&client_secret=%s&redirect_uri=https://%s/oauth/google&grant_type=authorization_code", | ||||
| 		codeFromURLParamaters, common.GoogleClientId, common.GoogleClientSecret, host, | ||||
| 	))) | ||||
|  | ||||
| 	req, _ := http.NewRequest("POST", | ||||
| 		"https://oauth2.googleapis.com/token", | ||||
| 		accessTokenBody, | ||||
| 	) | ||||
|  | ||||
| 	req.Header = http.Header{ | ||||
| 		"Content-Type": []string{"application/x-www-form-urlencoded"}, | ||||
| 		"Accept":       []string{"application/json"}, | ||||
| 	} | ||||
|  | ||||
| 	resp, err := RequestClient.Do(req) | ||||
|  | ||||
| 	if resp.StatusCode != 200 || err != nil { | ||||
| 		return nil, errors.New("访问令牌无效") | ||||
| 	} | ||||
|  | ||||
| 	var googleTokenResponse GoogleAccessTokenResponse | ||||
|  | ||||
| 	json.NewDecoder(resp.Body).Decode(&googleTokenResponse) | ||||
|  | ||||
| 	accessToken := "Bearer " + googleTokenResponse.AccessToken | ||||
|  | ||||
| 	// Get User Info | ||||
| 	req, _ = http.NewRequest("GET", "https://www.googleapis.com/oauth2/v3/userinfo", nil) | ||||
|  | ||||
| 	req.Header = http.Header{ | ||||
| 		"Content-Type":  []string{"application/json"}, | ||||
| 		"Authorization": []string{accessToken}, | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	resp, err = RequestClient.Do(req) | ||||
|  | ||||
| 	if resp.StatusCode != 200 || err != nil { | ||||
| 		return nil, errors.New("Google 用户信息无效") | ||||
| 	} | ||||
|  | ||||
| 	var googleUser GoogleUser | ||||
|  | ||||
| 	// Parse json to googleUser | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&googleUser) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if googleUser.Sub == "" { | ||||
| 		return nil, errors.New("返回值无效,用户字段为空,请稍后再试!") | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	return &googleUser, nil | ||||
| } | ||||
|  | ||||
| func GoogleOAuth(c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	username := session.Get("username") | ||||
| 	if username != nil { | ||||
| 		GoogleBind(c) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !common.GoogleOAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 Google 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
|  | ||||
| 	googleUser, err := getGoogleUserInfoByCode(code, c.Request.Host) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		GoogleId: googleUser.Sub, | ||||
| 	} | ||||
| 	if model.IsGoogleIdAlreadyTaken(user.GoogleId) { | ||||
| 		err := user.FillUserByGoogleId() | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": err.Error(), | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if common.RegisterEnabled { | ||||
| 			user.Username = "google_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			if googleUser.Name != "" { | ||||
| 				user.DisplayName = googleUser.Name | ||||
| 			} else { | ||||
| 				user.DisplayName = "Google User" | ||||
| 			} | ||||
| 			user.Role = common.RoleCommonUser | ||||
| 			user.Status = common.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员关闭了新用户注册", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if user.Status != common.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| func GoogleBind(c *gin.Context) { | ||||
| 	if !common.GoogleOAuthEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "管理员未开启通过 Google 登录以及注册", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
|  | ||||
| 	googleUser, err := getGoogleUserInfoByCode(code, c.Request.Host) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		GoogleId: googleUser.Sub, | ||||
| 	} | ||||
| 	if model.IsGoogleIdAlreadyTaken(user.GoogleId) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "该 Google 账户已被绑定", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	session := sessions.Default(c) | ||||
| 	id := session.Get("id") | ||||
| 	// id := c.GetInt("id")  // critical bug! | ||||
| 	user.Id = id.(int) | ||||
| 	err = user.FillUserById() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.GoogleId = googleUser.Sub | ||||
| 	err = user.Update(false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "bind", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| @@ -20,10 +21,6 @@ func GetStatus(c *gin.Context) { | ||||
| 			"email_verification":  common.EmailVerificationEnabled, | ||||
| 			"github_oauth":        common.GitHubOAuthEnabled, | ||||
| 			"github_client_id":    common.GitHubClientId, | ||||
| 			"discord_oauth":       common.DiscordOAuthEnabled, | ||||
| 			"discord_client_id":   common.DiscordClientId, | ||||
| 			"google_oauth":        common.GoogleOAuthEnabled, | ||||
| 			"google_client_id":    common.GoogleClientId, | ||||
| 			"system_name":         common.SystemName, | ||||
| 			"logo":                common.Logo, | ||||
| 			"footer_html":         common.Footer, | ||||
| @@ -83,6 +80,22 @@ func SendEmailVerification(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if common.EmailDomainRestrictionEnabled { | ||||
| 		allowed := false | ||||
| 		for _, domain := range common.EmailDomainWhitelist { | ||||
| 			if strings.HasSuffix(email, "@"+domain) { | ||||
| 				allowed = true | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		if !allowed { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	if model.IsEmailAlreadyTaken(email) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
|   | ||||
| @@ -288,6 +288,15 @@ func init() { | ||||
| 			Root:       "ERNIE-Bot-turbo", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "Embedding-V1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "baidu", | ||||
| 			Permission: permission, | ||||
| 			Root:       "Embedding-V1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "PaLM-2", | ||||
| 			Object:     "model", | ||||
| @@ -324,6 +333,33 @@ func init() { | ||||
| 			Root:       "chatglm_lite", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-v1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-v1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "qwen-plus-v1", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "ali", | ||||
| 			Permission: permission, | ||||
| 			Root:       "qwen-plus-v1", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         "SparkDesk", | ||||
| 			Object:     "model", | ||||
| 			Created:    1677649963, | ||||
| 			OwnedBy:    "xunfei", | ||||
| 			Permission: permission, | ||||
| 			Root:       "SparkDesk", | ||||
| 			Parent:     nil, | ||||
| 		}, | ||||
| 	} | ||||
| 	openAIModelsMap = make(map[string]OpenAIModels) | ||||
| 	for _, model := range openAIModels { | ||||
|   | ||||
| @@ -50,11 +50,11 @@ func UpdateOption(c *gin.Context) { | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	case "DiscordOAuthEnabled": | ||||
| 		if option.Value == "true" && common.DiscordClientId == "" { | ||||
| 	case "EmailDomainRestrictionEnabled": | ||||
| 		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用 Discord OAuth,请先填入 Discord Client ID 以及 Discord Client Secret!", | ||||
| 				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| @@ -66,14 +66,6 @@ func UpdateOption(c *gin.Context) { | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	case "GoogleOAuthEnabled": | ||||
| 		if option.Value == "true" && common.GoogleClientId == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "无法启用 Google OAuth,请先填入 Google Client ID 以及 Google Client Secret!", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	case "TurnstileCheckEnabled": | ||||
| 		if option.Value == "true" && common.TurnstileSiteKey == "" { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
							
								
								
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								controller/relay-ali.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,240 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r | ||||
|  | ||||
| type AliMessage struct { | ||||
| 	User string `json:"user"` | ||||
| 	Bot  string `json:"bot"` | ||||
| } | ||||
|  | ||||
| type AliInput struct { | ||||
| 	Prompt  string       `json:"prompt"` | ||||
| 	History []AliMessage `json:"history"` | ||||
| } | ||||
|  | ||||
| type AliParameters struct { | ||||
| 	TopP         float64 `json:"top_p,omitempty"` | ||||
| 	TopK         int     `json:"top_k,omitempty"` | ||||
| 	Seed         uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch bool    `json:"enable_search,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliChatRequest struct { | ||||
| 	Model      string        `json:"model"` | ||||
| 	Input      AliInput      `json:"input"` | ||||
| 	Parameters AliParameters `json:"parameters,omitempty"` | ||||
| } | ||||
|  | ||||
| type AliError struct { | ||||
| 	Code      string `json:"code"` | ||||
| 	Message   string `json:"message"` | ||||
| 	RequestId string `json:"request_id"` | ||||
| } | ||||
|  | ||||
| type AliUsage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
|  | ||||
| type AliOutput struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| } | ||||
|  | ||||
| type AliChatResponse struct { | ||||
| 	Output AliOutput `json:"output"` | ||||
| 	Usage  AliUsage  `json:"usage"` | ||||
| 	AliError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | ||||
| 	messages := make([]AliMessage, 0, len(request.Messages)) | ||||
| 	prompt := "" | ||||
| 	for i := 0; i < len(request.Messages); i++ { | ||||
| 		message := request.Messages[i] | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  "Okay", | ||||
| 			}) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			if i == len(request.Messages)-1 { | ||||
| 				prompt = message.Content | ||||
| 				break | ||||
| 			} | ||||
| 			messages = append(messages, AliMessage{ | ||||
| 				User: message.Content, | ||||
| 				Bot:  request.Messages[i+1].Content, | ||||
| 			}) | ||||
| 			i++ | ||||
| 		} | ||||
| 	} | ||||
| 	return &AliChatRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Input: AliInput{ | ||||
| 			Prompt:  prompt, | ||||
| 			History: messages, | ||||
| 		}, | ||||
| 		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's | ||||
| 		//	TopP: request.TopP, | ||||
| 		//	TopK: 50, | ||||
| 		//	//Seed:         0, | ||||
| 		//	//EnableSearch: false, | ||||
| 		//}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage: Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	choice.FinishReason = aliResponse.Output.FinishReason | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Id:      aliResponse.RequestId, | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "ernie-bot", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] != "data:" { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = data[5:] | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	lastResponseText := "" | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var aliResponse AliChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &aliResponse) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			usage.PromptTokens += aliResponse.Usage.InputTokens | ||||
| 			usage.CompletionTokens += aliResponse.Usage.OutputTokens | ||||
| 			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var aliResponse AliChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if aliResponse.Code != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: aliResponse.Message, | ||||
| 				Type:    aliResponse.Code, | ||||
| 				Param:   aliResponse.RequestId, | ||||
| 				Code:    aliResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseAli2OpenAI(&aliResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -54,13 +54,43 @@ type BaiduChatStreamResponse struct { | ||||
| 	IsEnd      bool `json:"is_end"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingRequest struct { | ||||
| 	Input []string `json:"input"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingData struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| 	Index     int       `json:"index"` | ||||
| } | ||||
|  | ||||
| type BaiduEmbeddingResponse struct { | ||||
| 	Id      string               `json:"id"` | ||||
| 	Object  string               `json:"object"` | ||||
| 	Created int64                `json:"created"` | ||||
| 	Data    []BaiduEmbeddingData `json:"data"` | ||||
| 	Usage   Usage                `json:"usage"` | ||||
| 	BaiduError | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||||
| 	messages := make([]BaiduMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		messages = append(messages, BaiduMessage{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.Content, | ||||
| 		}) | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, BaiduMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &BaiduChatRequest{ | ||||
| 		Messages: messages, | ||||
| @@ -101,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||
| 	baiduEmbeddingRequest := BaiduEmbeddingRequest{ | ||||
| 		Input: nil, | ||||
| 	} | ||||
| 	switch request.Input.(type) { | ||||
| 	case string: | ||||
| 		baiduEmbeddingRequest.Input = []string{request.Input.(string)} | ||||
| 	case []string: | ||||
| 		baiduEmbeddingRequest.Input = request.Input.([]string) | ||||
| 	} | ||||
| 	return &baiduEmbeddingRequest | ||||
| } | ||||
|  | ||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), | ||||
| 		Model:  "baidu-embedding", | ||||
| 		Usage:  response.Usage, | ||||
| 	} | ||||
| 	for _, item := range response.Data { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||
| 			Object:    item.Object, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|  | ||||
| func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| @@ -201,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var baiduResponse BaiduEmbeddingResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &baiduResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if baiduResponse.ErrorMsg != "" { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: baiduResponse.ErrorMsg, | ||||
| 				Type:    "baidu_error", | ||||
| 				Param:   "", | ||||
| 				Code:    baiduResponse.ErrorCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|   | ||||
| @@ -69,11 +69,11 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { | ||||
| 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | ||||
| 		} else if message.Role == "assistant" { | ||||
| 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | ||||
| 		} else { | ||||
| 			// ignore other roles | ||||
| 		} else if message.Role == "system" { | ||||
| 			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) | ||||
| 		} | ||||
| 		prompt += "\n\nAssistant:" | ||||
| 	} | ||||
| 	prompt += "\n\nAssistant:" | ||||
| 	claudeRequest.Prompt = prompt | ||||
| 	return &claudeRequest | ||||
| } | ||||
|   | ||||
| @@ -4,12 +4,11 @@ import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { | ||||
| @@ -35,23 +34,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			// ChatGPT Next Web | ||||
| 			if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { | ||||
| 				// Remove event: event in the front or back | ||||
| 				data = strings.TrimPrefix(data, "event: event") | ||||
| 				data = strings.TrimSuffix(data, "event: event") | ||||
| 				// Remove everything, only keep `data: {...}` <--- this is the json | ||||
| 				// Find the start and end indices of `data: {...}` substring | ||||
| 				startIndex := strings.Index(data, "data:") | ||||
| 				endIndex := strings.LastIndex(data, "}") | ||||
|  | ||||
| 				// If both indices are found and end index is greater than start index | ||||
| 				if startIndex != -1 && endIndex != -1 && endIndex > startIndex { | ||||
| 					// Extract the `data: {...}` substring | ||||
| 					data = data[startIndex : endIndex+1] | ||||
| 				} | ||||
| 			} | ||||
| 			if !strings.HasPrefix(data, "data:") { | ||||
| 			if data[:6] != "data: " && data[:6] != "[DONE]" { | ||||
| 				continue | ||||
| 			} | ||||
| 			dataChan <- data | ||||
| @@ -63,7 +46,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						return | ||||
| 						continue // just ignore the error | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Delta.Content | ||||
| @@ -73,7 +56,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						return | ||||
| 						continue | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Text | ||||
| @@ -109,7 +92,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O | ||||
| 	return nil, responseText | ||||
| } | ||||
|  | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var textResponse TextResponse | ||||
| 	if consumeQuota { | ||||
| 		responseBody, err := io.ReadAll(resp.Body) | ||||
| @@ -149,5 +132,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
|  | ||||
| 	if textResponse.Usage.TotalTokens == 0 { | ||||
| 		completionTokens := 0 | ||||
| 		for _, choice := range textResponse.Choices { | ||||
| 			completionTokens += countTokenText(choice.Message.Content, model) | ||||
| 		} | ||||
| 		textResponse.Usage = Usage{ | ||||
| 			PromptTokens:     promptTokens, | ||||
| 			CompletionTokens: completionTokens, | ||||
| 			TotalTokens:      promptTokens + completionTokens, | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, &textResponse.Usage | ||||
| } | ||||
|   | ||||
| @@ -20,6 +20,8 @@ const ( | ||||
| 	APITypePaLM | ||||
| 	APITypeBaidu | ||||
| 	APITypeZhipu | ||||
| 	APITypeAli | ||||
| 	APITypeXunfei | ||||
| ) | ||||
|  | ||||
| var httpClient *http.Client | ||||
| @@ -73,7 +75,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	isModelMapped := false | ||||
| 	if modelMapping != "" { | ||||
| 	if modelMapping != "" && modelMapping != "{}" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| 		if err != nil { | ||||
| @@ -94,6 +96,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 		apiType = APITypePaLM | ||||
| 	case common.ChannelTypeZhipu: | ||||
| 		apiType = APITypeZhipu | ||||
| 	case common.ChannelTypeAli: | ||||
| 		apiType = APITypeAli | ||||
| 	case common.ChannelTypeXunfei: | ||||
| 		apiType = APITypeXunfei | ||||
| 	} | ||||
| 	baseURL := common.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
| @@ -135,6 +141,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" | ||||
| 		case "BLOOMZ-7B": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" | ||||
| 		case "Embedding-V1": | ||||
| 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" | ||||
| 		} | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| @@ -153,6 +161,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			method = "sse-invoke" | ||||
| 		} | ||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||
| 	case APITypeAli: | ||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||
| 	} | ||||
| 	var promptTokens int | ||||
| 	var completionTokens int | ||||
| @@ -206,12 +216,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	case APITypeBaidu: | ||||
| 		baiduRequest := requestOpenAI2Baidu(textRequest) | ||||
| 		jsonStr, err := json.Marshal(baiduRequest) | ||||
| 		var jsonData []byte | ||||
| 		var err error | ||||
| 		switch relayMode { | ||||
| 		case RelayModeEmbeddings: | ||||
| 			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduEmbeddingRequest) | ||||
| 		default: | ||||
| 			baiduRequest := requestOpenAI2Baidu(textRequest) | ||||
| 			jsonData, err = json.Marshal(baiduRequest) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 		requestBody = bytes.NewBuffer(jsonData) | ||||
| 	case APITypePaLM: | ||||
| 		palmRequest := requestOpenAI2PaLM(textRequest) | ||||
| 		jsonStr, err := json.Marshal(palmRequest) | ||||
| @@ -226,49 +244,68 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} | ||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	apiKey := c.Request.Header.Get("Authorization") | ||||
| 	apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 	switch apiType { | ||||
| 	case APITypeOpenAI: | ||||
| 		if channelType == common.ChannelTypeAzure { | ||||
| 			req.Header.Set("api-key", apiKey) | ||||
| 		} else { | ||||
| 			req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 	case APITypeAli: | ||||
| 		aliRequest := requestOpenAI2Ali(textRequest) | ||||
| 		jsonStr, err := json.Marshal(aliRequest) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 	case APITypeClaude: | ||||
| 		req.Header.Set("x-api-key", apiKey) | ||||
| 		anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 		if anthropicVersion == "" { | ||||
| 			anthropicVersion = "2023-06-01" | ||||
| 		requestBody = bytes.NewBuffer(jsonStr) | ||||
| 	} | ||||
|  | ||||
| 	var req *http.Request | ||||
| 	var resp *http.Response | ||||
| 	isStream := textRequest.Stream | ||||
|  | ||||
| 	if apiType != APITypeXunfei { // cause xunfei use websocket | ||||
| 		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 	case APITypeZhipu: | ||||
| 		token := getZhipuToken(apiKey) | ||||
| 		req.Header.Set("Authorization", token) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 	//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 	resp, err := httpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = req.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| 	err = c.Request.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		apiKey := c.Request.Header.Get("Authorization") | ||||
| 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||
| 		switch apiType { | ||||
| 		case APITypeOpenAI: | ||||
| 			if channelType == common.ChannelTypeAzure { | ||||
| 				req.Header.Set("api-key", apiKey) | ||||
| 			} else { | ||||
| 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||||
| 			} | ||||
| 		case APITypeClaude: | ||||
| 			req.Header.Set("x-api-key", apiKey) | ||||
| 			anthropicVersion := c.Request.Header.Get("anthropic-version") | ||||
| 			if anthropicVersion == "" { | ||||
| 				anthropicVersion = "2023-06-01" | ||||
| 			} | ||||
| 			req.Header.Set("anthropic-version", anthropicVersion) | ||||
| 		case APITypeZhipu: | ||||
| 			token := getZhipuToken(apiKey) | ||||
| 			req.Header.Set("Authorization", token) | ||||
| 		case APITypeAli: | ||||
| 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||
| 			if textRequest.Stream { | ||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | ||||
| 			} | ||||
| 		} | ||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||
| 		//req.Header.Set("Connection", c.Request.Header.Get("Connection")) | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = req.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		err = c.Request.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	} | ||||
|  | ||||
| 	var textResponse TextResponse | ||||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||
| 	var streamResponseText string | ||||
|  | ||||
| 	defer func() { | ||||
| 		if consumeQuota { | ||||
| @@ -280,16 +317,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if strings.HasPrefix(textRequest.Model, "gpt-4") { | ||||
| 				completionRatio = 2 | ||||
| 			} | ||||
| 			if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu { | ||||
| 				completionTokens = countTokenText(streamResponseText, textRequest.Model) | ||||
| 			} else { | ||||
| 				promptTokens = textResponse.Usage.PromptTokens | ||||
| 				completionTokens = textResponse.Usage.CompletionTokens | ||||
| 				if apiType == APITypeZhipu { | ||||
| 					// zhipu's API does not return prompt tokens & completion tokens | ||||
| 					promptTokens = textResponse.Usage.TotalTokens | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			promptTokens = textResponse.Usage.PromptTokens | ||||
| 			completionTokens = textResponse.Usage.CompletionTokens | ||||
|  | ||||
| 			quota = promptTokens + int(float64(completionTokens)*completionRatio) | ||||
| 			quota = int(float64(quota) * ratio) | ||||
| 			if ratio != 0 && quota <= 0 { | ||||
| @@ -327,10 +358,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			streamResponseText = responseText | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := openaiHandler(c, resp, consumeQuota) | ||||
| 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @@ -345,7 +377,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			streamResponseText = responseText | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) | ||||
| @@ -368,7 +401,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := baiduHandler(c, resp) | ||||
| 			var err *OpenAIErrorWithStatusCode | ||||
| 			var usage *Usage | ||||
| 			switch relayMode { | ||||
| 			case RelayModeEmbeddings: | ||||
| 				err, usage = baiduEmbeddingHandler(c, resp) | ||||
| 			default: | ||||
| 				err, usage = baiduHandler(c, resp) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @@ -383,7 +423,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			streamResponseText = responseText | ||||
| 			textResponse.Usage.PromptTokens = promptTokens | ||||
| 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) | ||||
| @@ -404,6 +445,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := zhipuHandler(c, resp) | ||||
| @@ -413,8 +456,49 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			// zhipu's API does not return prompt tokens & completion tokens | ||||
| 			textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeAli: | ||||
| 		if isStream { | ||||
| 			err, usage := aliStreamHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			err, usage := aliHandler(c, resp) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 	case APITypeXunfei: | ||||
| 		if isStream { | ||||
| 			auth := c.Request.Header.Get("Authorization") | ||||
| 			auth = strings.TrimPrefix(auth, "Bearer ") | ||||
| 			splits := strings.Split(auth, "|") | ||||
| 			if len(splits) != 3 { | ||||
| 				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||
| 			} | ||||
| 			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			if usage != nil { | ||||
| 				textResponse.Usage = *usage | ||||
| 			} | ||||
| 			return nil | ||||
| 		} else { | ||||
| 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) | ||||
| 		} | ||||
| 	default: | ||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										278
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										278
									
								
								controller/relay-xunfei.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,278 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // https://console.xfyun.cn/services/cbm | ||||
| // https://www.xfyun.cn/doc/spark/Web.html | ||||
|  | ||||
| type XunfeiMessage struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Content string `json:"content"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatRequest struct { | ||||
| 	Header struct { | ||||
| 		AppId string `json:"app_id"` | ||||
| 	} `json:"header"` | ||||
| 	Parameter struct { | ||||
| 		Chat struct { | ||||
| 			Domain      string  `json:"domain,omitempty"` | ||||
| 			Temperature float64 `json:"temperature,omitempty"` | ||||
| 			TopK        int     `json:"top_k,omitempty"` | ||||
| 			MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 			Auditing    bool    `json:"auditing,omitempty"` | ||||
| 		} `json:"chat"` | ||||
| 	} `json:"parameter"` | ||||
| 	Payload struct { | ||||
| 		Message struct { | ||||
| 			Text []XunfeiMessage `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| } | ||||
|  | ||||
| type XunfeiChatResponse struct { | ||||
| 	Header struct { | ||||
| 		Code    int    `json:"code"` | ||||
| 		Message string `json:"message"` | ||||
| 		Sid     string `json:"sid"` | ||||
| 		Status  int    `json:"status"` | ||||
| 	} `json:"header"` | ||||
| 	Payload struct { | ||||
| 		Choices struct { | ||||
| 			Status int                          `json:"status"` | ||||
| 			Seq    int                          `json:"seq"` | ||||
| 			Text   []XunfeiChatResponseTextItem `json:"text"` | ||||
| 		} `json:"choices"` | ||||
| 		Usage struct { | ||||
| 			//Text struct { | ||||
| 			//	QuestionTokens   string `json:"question_tokens"` | ||||
| 			//	PromptTokens     string `json:"prompt_tokens"` | ||||
| 			//	CompletionTokens string `json:"completion_tokens"` | ||||
| 			//	TotalTokens      string `json:"total_tokens"` | ||||
| 			//} `json:"text"` | ||||
| 			Text Usage `json:"text"` | ||||
| 		} `json:"usage"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest { | ||||
| 	messages := make([]XunfeiMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, XunfeiMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	xunfeiRequest := XunfeiChatRequest{} | ||||
| 	xunfeiRequest.Header.AppId = xunfeiAppId | ||||
| 	xunfeiRequest.Parameter.Chat.Domain = "general" | ||||
| 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature | ||||
| 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
| 	return &xunfeiRequest | ||||
| } | ||||
|  | ||||
| func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	choice := OpenAITextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Payload.Choices.Text[0].Content, | ||||
| 		}, | ||||
| 	} | ||||
| 	fullTextResponse := OpenAITextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Choices: []OpenAITextResponseChoice{choice}, | ||||
| 		Usage:   response.Payload.Usage.Text, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { | ||||
| 	if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||
| 		xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ | ||||
| 			{ | ||||
| 				Content: "", | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 	var choice ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 	response := ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: common.GetTimestamp(), | ||||
| 		Model:   "SparkDesk", | ||||
| 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||
| 	HmacWithShaToBase64 := func(algorithm, data, key string) string { | ||||
| 		mac := hmac.New(sha256.New, []byte(key)) | ||||
| 		mac.Write([]byte(data)) | ||||
| 		encodeData := mac.Sum(nil) | ||||
| 		return base64.StdEncoding.EncodeToString(encodeData) | ||||
| 	} | ||||
| 	ul, err := url.Parse(hostUrl) | ||||
| 	if err != nil { | ||||
| 		fmt.Println(err) | ||||
| 	} | ||||
| 	date := time.Now().UTC().Format(time.RFC1123) | ||||
| 	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||
| 	sign := strings.Join(signString, "\n") | ||||
| 	sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) | ||||
| 	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||
| 		"hmac-sha256", "host date request-line", sha) | ||||
| 	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||
| 	v := url.Values{} | ||||
| 	v.Add("host", ul.Host) | ||||
| 	v.Add("date", date) | ||||
| 	v.Add("authorization", authorization) | ||||
| 	callUrl := hostUrl + "?" + v.Encode() | ||||
| 	return callUrl | ||||
| } | ||||
|  | ||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var usage Usage | ||||
| 	d := websocket.Dialer{ | ||||
| 		HandshakeTimeout: 5 * time.Second, | ||||
| 	} | ||||
| 	hostUrl := "wss://aichat.xf-yun.com/v1/chat" | ||||
| 	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) | ||||
| 	if err != nil || resp.StatusCode != 101 { | ||||
| 		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	data := requestOpenAI2Xunfei(textRequest, appId) | ||||
| 	err = conn.WriteJSON(data) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	dataChan := make(chan XunfeiChatResponse) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			_, msg, err := conn.ReadMessage() | ||||
| 			if err != nil { | ||||
| 				common.SysError("error reading stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			var response XunfeiChatResponse | ||||
| 			err = json.Unmarshal(msg, &response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				break | ||||
| 			} | ||||
| 			dataChan <- response | ||||
| 			if response.Payload.Choices.Status == 2 { | ||||
| 				err := conn.Close() | ||||
| 				if err != nil { | ||||
| 					common.SysError("error closing websocket connection: " + err.Error()) | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	c.Writer.Header().Set("Content-Type", "text/event-stream") | ||||
| 	c.Writer.Header().Set("Cache-Control", "no-cache") | ||||
| 	c.Writer.Header().Set("Connection", "keep-alive") | ||||
| 	c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||||
| 	c.Writer.Header().Set("X-Accel-Buffering", "no") | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case xunfeiResponse := <-dataChan: | ||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||
| 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||
| 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				common.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||
| 	var xunfeiResponse XunfeiChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &xunfeiResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if xunfeiResponse.Header.Code != 0 { | ||||
| 		return &OpenAIErrorWithStatusCode{ | ||||
| 			OpenAIError: OpenAIError{ | ||||
| 				Message: xunfeiResponse.Header.Message, | ||||
| 				Type:    "xunfei_error", | ||||
| 				Param:   "", | ||||
| 				Code:    xunfeiResponse.Header.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
| @@ -111,10 +111,21 @@ func getZhipuToken(apikey string) string { | ||||
| func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { | ||||
| 	messages := make([]ZhipuMessage, 0, len(request.Messages)) | ||||
| 	for _, message := range request.Messages { | ||||
| 		messages = append(messages, ZhipuMessage{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.Content, | ||||
| 		}) | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "system", | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    "user", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 		} else { | ||||
| 			messages = append(messages, ZhipuMessage{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.Content, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &ZhipuRequest{ | ||||
| 		Prompt:      messages, | ||||
| @@ -183,8 +194,8 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { | ||||
| 			return i + 2, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| @@ -197,14 +208,19 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			data = strings.Trim(data, "\"") | ||||
| 			if len(data) < 5 { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:5] == "data:" { | ||||
| 				dataChan <- data[5:] | ||||
| 			} else if data[:5] == "meta:" { | ||||
| 				metaChan <- data[5:] | ||||
| 			lines := strings.Split(data, "\n") | ||||
| 			for i, line := range lines { | ||||
| 				if len(line) < 5 { | ||||
| 					continue | ||||
| 				} | ||||
| 				if line[:5] == "data:" { | ||||
| 					dataChan <- line[5:] | ||||
| 					if i != len(lines)-1 { | ||||
| 						dataChan <- "\n" | ||||
| 					} | ||||
| 				} else if line[:5] == "meta:" { | ||||
| 					metaChan <- line[5:] | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		stopChan <- true | ||||
|   | ||||
| @@ -46,7 +46,6 @@ type ChatRequest struct { | ||||
| 	Model     string    `json:"model"` | ||||
| 	Messages  []Message `json:"messages"` | ||||
| 	MaxTokens int       `json:"max_tokens"` | ||||
| 	Stream    bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| type TextRequest struct { | ||||
| @@ -82,8 +81,9 @@ type OpenAIErrorWithStatusCode struct { | ||||
| } | ||||
|  | ||||
| type TextResponse struct { | ||||
| 	Usage `json:"usage"` | ||||
| 	Error OpenAIError `json:"error"` | ||||
| 	Choices []OpenAITextResponseChoice `json:"choices"` | ||||
| 	Usage   `json:"usage"` | ||||
| 	Error   OpenAIError `json:"error"` | ||||
| } | ||||
|  | ||||
| type OpenAITextResponseChoice struct { | ||||
| @@ -100,6 +100,19 @@ type OpenAITextResponse struct { | ||||
| 	Usage   `json:"usage"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponseItem struct { | ||||
| 	Object    string    `json:"object"` | ||||
| 	Index     int       `json:"index"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| } | ||||
|  | ||||
| type OpenAIEmbeddingResponse struct { | ||||
| 	Object string                        `json:"object"` | ||||
| 	Data   []OpenAIEmbeddingResponseItem `json:"data"` | ||||
| 	Model  string                        `json:"model"` | ||||
| 	Usage  `json:"usage"` | ||||
| } | ||||
|  | ||||
| type ImageResponse struct { | ||||
| 	Created int `json:"created"` | ||||
| 	Data    []struct { | ||||
|   | ||||
| @@ -2,7 +2,7 @@ version: '3.4' | ||||
|  | ||||
| services: | ||||
|   one-api: | ||||
|     image: ckt1031/one-api:latest | ||||
|     image: justsong/one-api:latest | ||||
|     container_name: one-api | ||||
|     restart: always | ||||
|     command: --log-dir /app/logs | ||||
|   | ||||
							
								
								
									
										37
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								go.mod
									
									
									
									
									
								
							| @@ -9,29 +9,29 @@ require ( | ||||
| 	github.com/gin-contrib/sessions v0.0.5 | ||||
| 	github.com/gin-contrib/static v0.0.1 | ||||
| 	github.com/gin-gonic/gin v1.9.1 | ||||
| 	github.com/go-playground/validator/v10 v10.14.1 | ||||
| 	github.com/go-playground/validator/v10 v10.14.0 | ||||
| 	github.com/go-redis/redis/v8 v8.11.5 | ||||
| 	github.com/golang-jwt/jwt v3.2.2+incompatible | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||
| 	golang.org/x/crypto v0.11.0 | ||||
| 	gorm.io/driver/mysql v1.5.1 | ||||
| 	gorm.io/driver/sqlite v1.5.2 | ||||
| 	gorm.io/gorm v1.25.2 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.1 | ||||
| 	golang.org/x/crypto v0.9.0 | ||||
| 	gorm.io/driver/mysql v1.4.3 | ||||
| 	gorm.io/driver/sqlite v1.4.3 | ||||
| 	gorm.io/gorm v1.24.0 | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/bytedance/sonic v1.9.2 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.2.0 // indirect | ||||
| 	github.com/bytedance/sonic v1.9.1 // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | ||||
| 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect | ||||
| 	github.com/chenzhuoyu/iasm v0.9.0 // indirect | ||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||
| 	github.com/dlclark/regexp2 v1.10.0 // indirect | ||||
| 	github.com/dlclark/regexp2 v1.8.1 // indirect | ||||
| 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | ||||
| 	github.com/gin-contrib/sse v0.1.0 // indirect | ||||
| 	github.com/go-playground/locales v0.14.1 // indirect | ||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.7.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.6.0 // indirect | ||||
| 	github.com/goccy/go-json v0.10.2 // indirect | ||||
| 	github.com/gorilla/context v1.1.1 // indirect | ||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | ||||
| @@ -39,20 +39,19 @@ require ( | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/jinzhu/now v1.1.5 // indirect | ||||
| 	github.com/json-iterator/go v1.1.12 // indirect | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.5 // indirect | ||||
| 	github.com/knz/go-libedit v1.10.1 // indirect | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | ||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.19 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | ||||
| 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect | ||||
| 	github.com/modern-go/reflect2 v1.0.2 // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.9 // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||
| 	golang.org/x/arch v0.4.0 // indirect | ||||
| 	golang.org/x/net v0.12.0 // indirect | ||||
| 	golang.org/x/sys v0.10.0 // indirect | ||||
| 	golang.org/x/text v0.11.0 // indirect | ||||
| 	google.golang.org/protobuf v1.31.0 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
| 	golang.org/x/net v0.10.0 // indirect | ||||
| 	golang.org/x/sys v0.8.0 // indirect | ||||
| 	golang.org/x/text v0.9.0 // indirect | ||||
| 	google.golang.org/protobuf v1.30.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										51
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										51
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,23 +1,11 @@ | ||||
| github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= | ||||
| github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= | ||||
| github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||
| github.com/bytedance/sonic v1.9.2 h1:GDaNjuWSGu09guE9Oql0MSTNhNCLlWwO8y/xM5BzcbM= | ||||
| github.com/bytedance/sonic v1.9.2/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= | ||||
| github.com/bytedance/sonic v1.10.0-rc h1:3S5HeWxjX08CUqNrXtEittExpJsEKBNzrV5UnrzHxVQ= | ||||
| github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= | ||||
| github.com/bytedance/sonic v1.10.0-rc2 h1:oDfRZ+4m6AYCOC0GFeOCeYqvBmucy1isvouS2K0cPzo= | ||||
| github.com/bytedance/sonic v1.10.0-rc2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= | ||||
| github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= | ||||
| github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= | ||||
| github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0= | ||||
| github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA= | ||||
| github.com/chenzhuoyu/iasm v0.9.0 h1:9fhXjVzq5hUy2gkhhgHl95zG2cEAhw9OSGs8toWWAwo= | ||||
| github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= | ||||
| github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= | ||||
| github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||
| github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||||
| @@ -26,8 +14,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r | ||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | ||||
| github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= | ||||
| github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | ||||
| github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= | ||||
| github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= | ||||
| github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= | ||||
| github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= | ||||
| @@ -59,15 +45,10 @@ github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GO | ||||
| github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= | ||||
| github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= | ||||
| github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= | ||||
| github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+jU0zvx4AqHGnv4k= | ||||
| github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= | ||||
| github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= | ||||
| github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= | ||||
| github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= | ||||
| github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= | ||||
| github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= | ||||
| github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= | ||||
| github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= | ||||
| github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||
| github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= | ||||
| github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= | ||||
| @@ -86,6 +67,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC | ||||
| github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | ||||
| github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= | ||||
| github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | ||||
| github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||
| github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||
| github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||||
| github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||||
| github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||||
| @@ -97,10 +80,6 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm | ||||
| github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= | ||||
| github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= | ||||
| github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= | ||||
| github.com/knz/go-libedit v1.10.1 h1:0pHpWtx9vcvC0xGZqEQlQdfSQs7WRlAjuPvk3fOZDCo= | ||||
| github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= | ||||
| github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= | ||||
| github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= | ||||
| github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= | ||||
| @@ -132,13 +111,9 @@ github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= | ||||
| github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= | ||||
| github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= | ||||
| github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= | ||||
| github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= | ||||
| github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= | ||||
| github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= | ||||
| github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo= | ||||
| github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= | ||||
| github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= | ||||
| github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= | ||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||
| github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= | ||||
| @@ -157,7 +132,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o | ||||
| github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= | ||||
| github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= | ||||
| github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | ||||
| github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | ||||
| github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= | ||||
| github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | ||||
| github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= | ||||
| @@ -169,38 +143,27 @@ github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ | ||||
| golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc= | ||||
| golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= | ||||
| golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= | ||||
| golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= | ||||
| golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= | ||||
| golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= | ||||
| golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= | ||||
| golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= | ||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= | ||||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= | ||||
| golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= | ||||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | ||||
| golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= | ||||
| golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| @@ -208,8 +171,6 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 | ||||
| google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= | ||||
| google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= | ||||
| google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||
| @@ -226,17 +187,9 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||
| gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= | ||||
| gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= | ||||
| gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw= | ||||
| gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o= | ||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | ||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | ||||
| gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc= | ||||
| gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= | ||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | ||||
| gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= | ||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | ||||
| gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||
| gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho= | ||||
| gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||
| nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= | ||||
| rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= | ||||
|   | ||||
| @@ -503,5 +503,12 @@ | ||||
|   "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", | ||||
|   "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", | ||||
|   "Homepage URL 填": "Fill in the Homepage URL", | ||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL" | ||||
|   "Authorization callback URL 填": "Fill in the Authorization callback URL", | ||||
|   "请为通道命名": "Please name the channel", | ||||
|   "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", | ||||
|   "模型重定向": "Model redirection", | ||||
|   "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", | ||||
|   "注意,": "Note that, ", | ||||
|   ",图片演示。": "related image demo.", | ||||
|   "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!" | ||||
| } | ||||
|   | ||||
| @@ -2,7 +2,6 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"one-api/common" | ||||
| 	"one-api/model" | ||||
| @@ -13,8 +12,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| 	Model  string `json:"model"` | ||||
| 	Stream bool   `json:"stream" default:"true"` | ||||
| 	Model string `json:"model"` | ||||
| } | ||||
|  | ||||
| func Distribute() func(c *gin.Context) { | ||||
| @@ -86,8 +84,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 					modelRequest.Model = "dall-e" | ||||
| 				} | ||||
| 			} | ||||
| 			log.Print(modelRequest.Stream) | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream) | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				if channel != nil { | ||||
|   | ||||
| @@ -1,36 +1,24 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Ability struct { | ||||
| 	Group             string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` | ||||
| 	Model             string `json:"model" gorm:"primaryKey;autoIncrement:false"` | ||||
| 	ChannelId         int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | ||||
| 	Enabled           bool   `json:"enabled"` | ||||
| 	AllowStreaming    int    `json:"allow_streaming" gorm:"default:1"` | ||||
| 	AllowNonStreaming int    `json:"allow_non_streaming" gorm:"default:1"` | ||||
| 	Group     string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` | ||||
| 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"` | ||||
| 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | ||||
| 	Enabled   bool   `json:"enabled"` | ||||
| } | ||||
|  | ||||
| func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) { | ||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	ability := Ability{} | ||||
| 	var err error = nil | ||||
|  | ||||
| 	cmd := "`group` = ? and model = ? and enabled = 1" | ||||
|  | ||||
| 	if stream { | ||||
| 		cmd += fmt.Sprintf(" and allow_streaming = %d", common.ChannelAllowStreamEnabled) | ||||
| 	} else { | ||||
| 		cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled) | ||||
| 	} | ||||
|  | ||||
| 	if common.UsingSQLite { | ||||
| 		err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error | ||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error | ||||
| 	} else { | ||||
| 		err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error | ||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -48,12 +36,10 @@ func (channel *Channel) AddAbilities() error { | ||||
| 	for _, model := range models_ { | ||||
| 		for _, group := range groups_ { | ||||
| 			ability := Ability{ | ||||
| 				Group:             group, | ||||
| 				Model:             model, | ||||
| 				ChannelId:         channel.Id, | ||||
| 				Enabled:           channel.Status == common.ChannelStatusEnabled, | ||||
| 				AllowStreaming:    channel.AllowStreaming, | ||||
| 				AllowNonStreaming: channel.AllowNonStreaming, | ||||
| 				Group:     group, | ||||
| 				Model:     model, | ||||
| 				ChannelId: channel.Id, | ||||
| 				Enabled:   channel.Status == common.ChannelStatusEnabled, | ||||
| 			} | ||||
| 			abilities = append(abilities, ability) | ||||
| 		} | ||||
|   | ||||
| @@ -160,9 +160,9 @@ func SyncChannelCache(frequency int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) { | ||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return GetRandomSatisfiedChannel(group, model, stream) | ||||
| 		return GetRandomSatisfiedChannel(group, model) | ||||
| 	} | ||||
| 	channelSyncLock.RLock() | ||||
| 	defer channelSyncLock.RUnlock() | ||||
| @@ -170,14 +170,6 @@ func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*C | ||||
| 	if len(channels) == 0 { | ||||
| 		return nil, errors.New("channel not found") | ||||
| 	} | ||||
|  | ||||
| 	var filteredChannels []*Channel | ||||
| 	for _, channel := range channels { | ||||
| 		if (stream && channel.AllowStreaming == common.ChannelAllowStreamEnabled) || (!stream && channel.AllowNonStreaming == common.ChannelAllowNonStreamEnabled) { | ||||
| 			filteredChannels = append(filteredChannels, channel) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	idx := rand.Intn(len(filteredChannels)) | ||||
| 	return filteredChannels[idx], nil | ||||
| 	idx := rand.Intn(len(channels)) | ||||
| 	return channels[idx], nil | ||||
| } | ||||
|   | ||||
| @@ -1,9 +1,8 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"one-api/common" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| ) | ||||
|  | ||||
| type Channel struct { | ||||
| @@ -24,8 +23,6 @@ type Channel struct { | ||||
| 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||
| 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	AllowStreaming     int     `json:"allow_streaming" gorm:"default:1"` | ||||
| 	AllowNonStreaming  int     `json:"allow_non_streaming" gorm:"default:1"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
|   | ||||
| @@ -30,9 +30,7 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) | ||||
| 	common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) | ||||
| 	common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) | ||||
| 	common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled) | ||||
| 	common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) | ||||
| 	common.OptionMap["GoogleOAuthEnabled"] = strconv.FormatBool(common.GoogleOAuthEnabled) | ||||
| 	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) | ||||
| 	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) | ||||
| 	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) | ||||
| @@ -41,6 +39,8 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) | ||||
| 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) | ||||
| 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) | ||||
| 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) | ||||
| 	common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") | ||||
| 	common.OptionMap["SMTPServer"] = "" | ||||
| 	common.OptionMap["SMTPFrom"] = "" | ||||
| 	common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) | ||||
| @@ -55,13 +55,9 @@ func InitOptionMap() { | ||||
| 	common.OptionMap["ServerAddress"] = "" | ||||
| 	common.OptionMap["GitHubClientId"] = "" | ||||
| 	common.OptionMap["GitHubClientSecret"] = "" | ||||
| 	common.OptionMap["DiscordClientId"] = "" | ||||
| 	common.OptionMap["DiscordClientSecret"] = "" | ||||
| 	common.OptionMap["WeChatServerAddress"] = "" | ||||
| 	common.OptionMap["WeChatServerToken"] = "" | ||||
| 	common.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||
| 	common.OptionMap["GoogleClientId"] = "" | ||||
| 	common.OptionMap["GoogleClientSecret"] = "" | ||||
| 	common.OptionMap["TurnstileSiteKey"] = "" | ||||
| 	common.OptionMap["TurnstileSecretKey"] = "" | ||||
| 	common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) | ||||
| @@ -141,16 +137,14 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 			common.EmailVerificationEnabled = boolValue | ||||
| 		case "GitHubOAuthEnabled": | ||||
| 			common.GitHubOAuthEnabled = boolValue | ||||
| 		case "DiscordOAuthEnabled": | ||||
| 			common.DiscordOAuthEnabled = boolValue | ||||
| 		case "WeChatAuthEnabled": | ||||
| 			common.WeChatAuthEnabled = boolValue | ||||
| 		case "GoogleOAuthEnabled": | ||||
| 			common.GoogleOAuthEnabled = boolValue | ||||
| 		case "TurnstileCheckEnabled": | ||||
| 			common.TurnstileCheckEnabled = boolValue | ||||
| 		case "RegisterEnabled": | ||||
| 			common.RegisterEnabled = boolValue | ||||
| 		case "EmailDomainRestrictionEnabled": | ||||
| 			common.EmailDomainRestrictionEnabled = boolValue | ||||
| 		case "AutomaticDisableChannelEnabled": | ||||
| 			common.AutomaticDisableChannelEnabled = boolValue | ||||
| 		case "ApproximateTokenEnabled": | ||||
| @@ -164,6 +158,8 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		} | ||||
| 	} | ||||
| 	switch key { | ||||
| 	case "EmailDomainWhitelist": | ||||
| 		common.EmailDomainWhitelist = strings.Split(value, ",") | ||||
| 	case "SMTPServer": | ||||
| 		common.SMTPServer = value | ||||
| 	case "SMTPPort": | ||||
| @@ -181,10 +177,6 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		common.GitHubClientId = value | ||||
| 	case "GitHubClientSecret": | ||||
| 		common.GitHubClientSecret = value | ||||
| 	case "DiscordClientId": | ||||
| 		common.DiscordClientId = value | ||||
| 	case "DiscordClientSecret": | ||||
| 		common.DiscordClientSecret = value | ||||
| 	case "Footer": | ||||
| 		common.Footer = value | ||||
| 	case "SystemName": | ||||
| @@ -197,10 +189,6 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		common.WeChatServerToken = value | ||||
| 	case "WeChatAccountQRCodeImageURL": | ||||
| 		common.WeChatAccountQRCodeImageURL = value | ||||
| 	case "GoogleClientId": | ||||
| 		common.GoogleClientId = value | ||||
| 	case "GoogleClientSecret": | ||||
| 		common.GoogleClientSecret = value | ||||
| 	case "TurnstileSiteKey": | ||||
| 		common.TurnstileSiteKey = value | ||||
| 	case "TurnstileSecretKey": | ||||
|   | ||||
| @@ -3,10 +3,9 @@ package model | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gorm.io/gorm" | ||||
| 	"one-api/common" | ||||
| 	"strings" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // User if you add sensitive fields, don't forget to clean them in setupLogin function. | ||||
| @@ -20,9 +19,7 @@ type User struct { | ||||
| 	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled | ||||
| 	Email            string `json:"email" gorm:"index" validate:"max=50"` | ||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||
| 	DiscordId        string `json:"discord_id" gorm:"column:discord_id;index"` | ||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||
| 	GoogleId         string `json:"google_id" gorm:"column:google_id;index"` | ||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||
| 	Quota            int    `json:"quota" gorm:"type:int;default:0"` | ||||
| @@ -172,14 +169,6 @@ func (user *User) FillUserByGitHubId() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByDiscordId() error { | ||||
| 	if user.DiscordId == "" { | ||||
| 		return errors.New("Discord id 为空!") | ||||
| 	} | ||||
| 	DB.Where(User{DiscordId: user.DiscordId}).First(user) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByWeChatId() error { | ||||
| 	if user.WeChatId == "" { | ||||
| 		return errors.New("WeChat id 为空!") | ||||
| @@ -188,14 +177,6 @@ func (user *User) FillUserByWeChatId() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByGoogleId() error { | ||||
| 	if user.GoogleId == "" { | ||||
| 		return errors.New("Google id 为空!") | ||||
| 	} | ||||
| 	DB.Where(User{GoogleId: user.GoogleId}).First(user) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByUsername() error { | ||||
| 	if user.Username == "" { | ||||
| 		return errors.New("username 为空!") | ||||
| @@ -212,14 +193,6 @@ func IsWeChatIdAlreadyTaken(wechatId string) bool { | ||||
| 	return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsDiscordIdAlreadyTaken(discordId string) bool { | ||||
| 	return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsGoogleIdAlreadyTaken(googleId string) bool { | ||||
| 	return DB.Where("google_id = ?", googleId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsGitHubIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|   | ||||
| @@ -21,16 +21,14 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | ||||
| 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | ||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | ||||
| 		apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth) | ||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | ||||
| 		apiRouter.GET("/oauth/google", middleware.CriticalRateLimit(), controller.GoogleOAuth) | ||||
| 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | ||||
| 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | ||||
|  | ||||
| 		userRoute := apiRouter.Group("/user") | ||||
| 		{ | ||||
| 			userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) | ||||
| 			userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login) | ||||
| 			userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login) | ||||
| 			userRoute.GET("/logout", controller.Logout) | ||||
|  | ||||
| 			selfRoute := userRoute.Group("/") | ||||
|   | ||||
| @@ -3,19 +3,18 @@ | ||||
|   "version": "0.1.0", | ||||
|   "private": true, | ||||
|   "dependencies": { | ||||
|     "@babel/plugin-proposal-private-property-in-object": "^7.21.11", | ||||
|     "axios": "^1.4.0", | ||||
|     "axios": "^0.27.2", | ||||
|     "history": "^5.3.0", | ||||
|     "marked": "^5.1.1", | ||||
|     "marked": "^4.1.1", | ||||
|     "react": "^18.2.0", | ||||
|     "react-dom": "^18.2.0", | ||||
|     "react-dropzone": "^14.2.3", | ||||
|     "react-router-dom": "^6.14.2", | ||||
|     "react-router-dom": "^6.3.0", | ||||
|     "react-scripts": "5.0.1", | ||||
|     "react-toastify": "^9.1.3", | ||||
|     "react-turnstile": "^1.1.1", | ||||
|     "react-toastify": "^9.0.8", | ||||
|     "react-turnstile": "^1.0.5", | ||||
|     "semantic-ui-css": "^2.5.0", | ||||
|     "semantic-ui-react": "^2.1.4" | ||||
|     "semantic-ui-react": "^2.1.3" | ||||
|   }, | ||||
|   "scripts": { | ||||
|     "start": "react-scripts start", | ||||
| @@ -42,7 +41,7 @@ | ||||
|     ] | ||||
|   }, | ||||
|   "devDependencies": { | ||||
|     "prettier": "^3.0.0" | ||||
|     "prettier": "^2.7.1" | ||||
|   }, | ||||
|   "prettier": { | ||||
|     "singleQuote": true, | ||||
|   | ||||
| @@ -12,8 +12,6 @@ import AddUser from './pages/User/AddUser'; | ||||
| import { API, getLogo, getSystemName, showError, showNotice } from './helpers'; | ||||
| import PasswordResetForm from './components/PasswordResetForm'; | ||||
| import GitHubOAuth from './components/GitHubOAuth'; | ||||
| import DiscordOAuth from './components/DiscordOAuth'; | ||||
| import GoogleOAuth from './components/GoogleOAuth'; | ||||
| import PasswordResetConfirm from './components/PasswordResetConfirm'; | ||||
| import { UserContext } from './context/User'; | ||||
| import { StatusContext } from './context/Status'; | ||||
| @@ -241,24 +239,6 @@ function App() { | ||||
|           </Suspense> | ||||
|         } | ||||
|       /> | ||||
|       <Route | ||||
|         HEAD | ||||
|         path='/oauth/discord' | ||||
|         element={ | ||||
|           <Suspense fallback={<Loading></Loading>}> | ||||
|             <DiscordOAuth /> | ||||
|           </Suspense> | ||||
|         } | ||||
|       /> | ||||
|       <Route | ||||
|         path='/oauth/google' | ||||
|         element={ | ||||
|           <Suspense fallback={<Loading></Loading>}> | ||||
|             <GoogleOAuth /> | ||||
|             support-google-oauth | ||||
|           </Suspense> | ||||
|         } | ||||
|       /> | ||||
|       <Route | ||||
|         path='/setting' | ||||
|         element={ | ||||
| @@ -272,11 +252,11 @@ function App() { | ||||
|       <Route | ||||
|         path='/topup' | ||||
|         element={ | ||||
|           <PrivateRoute> | ||||
|             <Suspense fallback={<Loading></Loading>}> | ||||
|               <TopUp /> | ||||
|             </Suspense> | ||||
|           </PrivateRoute> | ||||
|         <PrivateRoute> | ||||
|           <Suspense fallback={<Loading></Loading>}> | ||||
|             <TopUp /> | ||||
|           </Suspense> | ||||
|         </PrivateRoute> | ||||
|         } | ||||
|       /> | ||||
|       <Route | ||||
|   | ||||
| @@ -363,9 +363,12 @@ const ChannelsTable = () => { | ||||
|                   </Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <Popup | ||||
|                       content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'} | ||||
|                       key={channel.id} | ||||
|                       trigger={renderBalance(channel.type, channel.balance)} | ||||
|                       trigger={<span onClick={() => { | ||||
|                         updateChannelBalance(channel.id, channel.name, idx); | ||||
|                       }} style={{ cursor: 'pointer' }}> | ||||
|                       {renderBalance(channel.type, channel.balance)} | ||||
|                     </span>} | ||||
|                       content="点击更新" | ||||
|                       basic | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
| @@ -380,16 +383,16 @@ const ChannelsTable = () => { | ||||
|                       > | ||||
|                         测试 | ||||
|                       </Button> | ||||
|                       <Button | ||||
|                         size={'small'} | ||||
|                         positive | ||||
|                         loading={updatingBalance} | ||||
|                         onClick={() => { | ||||
|                           updateChannelBalance(channel.id, channel.name, idx); | ||||
|                         }} | ||||
|                       > | ||||
|                         更新余额 | ||||
|                       </Button> | ||||
|                       {/*<Button*/} | ||||
|                       {/*  size={'small'}*/} | ||||
|                       {/*  positive*/} | ||||
|                       {/*  loading={updatingBalance}*/} | ||||
|                       {/*  onClick={() => {*/} | ||||
|                       {/*    updateChannelBalance(channel.id, channel.name, idx);*/} | ||||
|                       {/*  }}*/} | ||||
|                       {/*>*/} | ||||
|                       {/*  更新余额*/} | ||||
|                       {/*</Button>*/} | ||||
|                       <Popup | ||||
|                         trigger={ | ||||
|                           <Button size='small' negative> | ||||
|   | ||||
| @@ -1,57 +0,0 @@ | ||||
| import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { Dimmer, Loader, Segment } from 'semantic-ui-react'; | ||||
| import { useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { API, showError, showSuccess } from '../helpers'; | ||||
| import { UserContext } from '../context/User'; | ||||
|  | ||||
| const DiscordOAuth = () => { | ||||
|   const [searchParams, setSearchParams] = useSearchParams(); | ||||
|  | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   const [prompt, setPrompt] = useState('处理中...'); | ||||
|   const [processing, setProcessing] = useState(true); | ||||
|  | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const sendCode = async (code, count) => { | ||||
|     const res = await API.get(`/api/oauth/discord?code=${code}`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       if (message === 'bind') { | ||||
|         showSuccess('绑定成功!'); | ||||
|         navigate('/setting'); | ||||
|       } else { | ||||
|         userDispatch({ type: 'login', payload: data }); | ||||
|         localStorage.setItem('user', JSON.stringify(data)); | ||||
|         showSuccess('登录成功!'); | ||||
|         navigate('/'); | ||||
|       } | ||||
|     } else { | ||||
|       showError(message); | ||||
|       if (count === 0) { | ||||
|         setPrompt(`操作失败,重定向至登录界面中...`); | ||||
|         navigate('/setting'); // in case this is failed to bind GitHub | ||||
|         return; | ||||
|       } | ||||
|       count++; | ||||
|       setPrompt(`出现错误,第 ${count} 次重试中...`); | ||||
|       await new Promise((resolve) => setTimeout(resolve, count * 2000)); | ||||
|       await sendCode(code, count); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let code = searchParams.get('code'); | ||||
|     sendCode(code, 0).then(); | ||||
|   }, []); | ||||
|  | ||||
|   return ( | ||||
|     <Segment style={{ minHeight: '300px' }}> | ||||
|       <Dimmer active inverted> | ||||
|         <Loader size='large'>{prompt}</Loader> | ||||
|       </Dimmer> | ||||
|     </Segment> | ||||
|   ); | ||||
| }; | ||||
|  | ||||
| export default DiscordOAuth; | ||||
| @@ -1,57 +0,0 @@ | ||||
| import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { Dimmer, Loader, Segment } from 'semantic-ui-react'; | ||||
| import { useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { API, showError, showSuccess } from '../helpers'; | ||||
| import { UserContext } from '../context/User'; | ||||
|  | ||||
| const GoogleOAuth = () => { | ||||
|   const [searchParams, setSearchParams] = useSearchParams(); | ||||
|  | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   const [prompt, setPrompt] = useState('处理中...'); | ||||
|   const [processing, setProcessing] = useState(true); | ||||
|  | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const sendCode = async (code, count) => { | ||||
|     const res = await API.get(`/api/oauth/google?code=${code}`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       if (message === 'bind') { | ||||
|         showSuccess('绑定成功!'); | ||||
|         navigate('/setting'); | ||||
|       } else { | ||||
|         userDispatch({ type: 'login', payload: data }); | ||||
|         localStorage.setItem('user', JSON.stringify(data)); | ||||
|         showSuccess('登录成功!'); | ||||
|         navigate('/'); | ||||
|       } | ||||
|     } else { | ||||
|       showError(message); | ||||
|       if (count === 0) { | ||||
|         setPrompt(`操作失败,重定向至登录界面中...`); | ||||
|         navigate('/setting'); // in case this is failed to bind GitHub | ||||
|         return; | ||||
|       } | ||||
|       count++; | ||||
|       setPrompt(`出现错误,第 ${count} 次重试中...`); | ||||
|       await new Promise((resolve) => setTimeout(resolve, count * 2000)); | ||||
|       await sendCode(code, count); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let code = searchParams.get('code'); | ||||
|     sendCode(code, 0).then(); | ||||
|   }, []); | ||||
|  | ||||
|   return ( | ||||
|     <Segment style={{ minHeight: '300px' }}> | ||||
|       <Dimmer active inverted> | ||||
|         <Loader size='large'>{prompt}</Loader> | ||||
|       </Dimmer> | ||||
|     </Segment> | ||||
|   ); | ||||
| }; | ||||
|  | ||||
| export default GoogleOAuth; | ||||
| @@ -2,8 +2,7 @@ import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { UserContext } from '../context/User'; | ||||
| import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; | ||||
| import Turnstile from 'react-turnstile'; | ||||
| import { API, getLogo, showError, showSuccess } from '../helpers'; | ||||
|  | ||||
| const LoginForm = () => { | ||||
|   const [inputs, setInputs] = useState({ | ||||
| @@ -15,9 +14,6 @@ const LoginForm = () => { | ||||
|   const [submitted, setSubmitted] = useState(false); | ||||
|   const { username, password } = inputs; | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   const [turnstileEnabled, setTurnstileEnabled] = useState(false); | ||||
|   const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); | ||||
|   const [turnstileToken, setTurnstileToken] = useState(''); | ||||
|   let navigate = useNavigate(); | ||||
|   const [status, setStatus] = useState({}); | ||||
|   const logo = getLogo(); | ||||
| @@ -30,34 +26,17 @@ const LoginForm = () => { | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       setStatus(status); | ||||
|  | ||||
|       if (status.turnstile_check) { | ||||
|         setTurnstileEnabled(true); | ||||
|         setTurnstileSiteKey(status.turnstile_site_key); | ||||
|       } | ||||
|     } | ||||
|   }, []); | ||||
|  | ||||
|   const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); | ||||
|  | ||||
|   const openGoogleOAuth = () => { | ||||
|     window.open( | ||||
|       `https://accounts.google.com/o/oauth2/v2/auth?client_id=${status.google_client_id}&redirect_uri=${window.location.origin}/oauth/google&response_type=code&scope=profile` | ||||
|     ); | ||||
|   }; | ||||
|  | ||||
|   const onGitHubOAuthClicked = () => { | ||||
|     window.open( | ||||
|       `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` | ||||
|     ); | ||||
|   }; | ||||
|  | ||||
|   const onDiscordOAuthClicked = () => { | ||||
|     window.open( | ||||
|       `https://discord.com/oauth2/authorize?response_type=code&client_id=${status.discord_client_id}&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`, | ||||
|     ); | ||||
|   }; | ||||
|  | ||||
|   const onWeChatLoginClicked = () => { | ||||
|     setShowWeChatLoginModal(true); | ||||
|   }; | ||||
| @@ -86,12 +65,7 @@ const LoginForm = () => { | ||||
|   async function handleSubmit(e) { | ||||
|     setSubmitted(true); | ||||
|     if (username && password) { | ||||
|       if (turnstileEnabled && turnstileToken === '') { | ||||
|         showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); | ||||
|         return; | ||||
|       } | ||||
|  | ||||
|       const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, { | ||||
|       const res = await API.post(`/api/user/login`, { | ||||
|         username, | ||||
|         password | ||||
|       }); | ||||
| @@ -134,16 +108,6 @@ const LoginForm = () => { | ||||
|               value={password} | ||||
|               onChange={handleChange} | ||||
|             /> | ||||
|             {turnstileEnabled ? ( | ||||
|               <Turnstile | ||||
|                 sitekey={turnstileSiteKey} | ||||
|                 onVerify={(token) => { | ||||
|                   setTurnstileToken(token); | ||||
|                 }} | ||||
|               /> | ||||
|             ) : ( | ||||
|               <></> | ||||
|             )} | ||||
|             <Button color='green' fluid size='large' onClick={handleSubmit}> | ||||
|               登录 | ||||
|             </Button> | ||||
| @@ -159,40 +123,28 @@ const LoginForm = () => { | ||||
|             点击注册 | ||||
|           </Link> | ||||
|         </Message> | ||||
|         {status.github_oauth || status.wechat_login || status.discord_oauth || status.google_oauth ? ( | ||||
|         {status.github_oauth || status.wechat_login ? ( | ||||
|           <> | ||||
|             <Divider horizontal>Or</Divider> | ||||
|             {status.discord_oauth && ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color='blue' | ||||
|                 icon='discord' | ||||
|                 onClick={onDiscordOAuthClicked} | ||||
|               /> | ||||
|             )} | ||||
|             {status.github_oauth && ( | ||||
|             {status.github_oauth ? ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color='black' | ||||
|                 icon='github' | ||||
|                 onClick={onGitHubOAuthClicked} | ||||
|               /> | ||||
|             ) : ( | ||||
|               <></> | ||||
|             )} | ||||
|             {status.wechat_login && ( | ||||
|             {status.wechat_login ? ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color='green' | ||||
|                 icon='wechat' | ||||
|                 onClick={onWeChatLoginClicked} | ||||
|               /> | ||||
|             )} | ||||
|             {status.google_oauth && ( | ||||
|               <Button | ||||
|                 circular | ||||
|                 color='red' | ||||
|                 icon='google' | ||||
|                 onClick={openGoogleOAuth} | ||||
|               /> | ||||
|             ) : ( | ||||
|               <></> | ||||
|             )} | ||||
|           </> | ||||
|         ) : ( | ||||
|   | ||||
| @@ -112,24 +112,12 @@ const PersonalSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const openGoogleOAuth = () => { | ||||
|     window.open( | ||||
|       `https://accounts.google.com/o/oauth2/v2/auth?client_id=${status.google_client_id}&redirect_uri=${window.location.origin}/oauth/google&response_type=code&scope=https://www.googleapis.com/auth/userinfo.profile` | ||||
|     ); | ||||
|   }; | ||||
|  | ||||
|   const openGitHubOAuth = () => { | ||||
|     window.open( | ||||
|       `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` | ||||
|     ); | ||||
|   }; | ||||
|  | ||||
|   const openDiscordOAuth = () => { | ||||
|     window.open( | ||||
|       `https://discord.com/api/oauth2/authorize?client_id=${status.discord_client_id}&response_type=code&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`, | ||||
|     ); | ||||
|   }; | ||||
|  | ||||
|   const sendVerificationCode = async () => { | ||||
|     setDisableButton(true); | ||||
|     if (inputs.email === '') return; | ||||
| @@ -227,17 +215,6 @@ const PersonalSetting = () => { | ||||
|           <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button> | ||||
|         ) | ||||
|       } | ||||
|       { | ||||
|         status.discord_oauth && ( | ||||
|           <Button onClick={openDiscordOAuth}>绑定 Discord 账号</Button> | ||||
|         ) | ||||
|       } | ||||
|       { | ||||
|         status.google_oauth && ( | ||||
|           <Button onClick={openGoogleOAuth}>绑定 Google 账号</Button> | ||||
|  | ||||
|         ) | ||||
|       } | ||||
|       <Button | ||||
|         onClick={() => { | ||||
|           setShowEmailBindModal(true); | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react'; | ||||
| import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers'; | ||||
| import { Button, Divider, Form, Grid, Header, Input, Message } from 'semantic-ui-react'; | ||||
| import { API, removeTrailingSlash, showError } from '../helpers'; | ||||
|  | ||||
| const SystemSetting = () => { | ||||
|   let [inputs, setInputs] = useState({ | ||||
| @@ -8,11 +8,8 @@ const SystemSetting = () => { | ||||
|     PasswordRegisterEnabled: '', | ||||
|     EmailVerificationEnabled: '', | ||||
|     GitHubOAuthEnabled: '', | ||||
|     DiscordOAuthEnabled: '', | ||||
|     GitHubClientId: '', | ||||
|     GitHubClientSecret: '', | ||||
|     DiscordClientId: '', | ||||
|     DiscordClientSecret: '', | ||||
|     Notice: '', | ||||
|     SMTPServer: '', | ||||
|     SMTPPort: '', | ||||
| @@ -25,16 +22,17 @@ const SystemSetting = () => { | ||||
|     WeChatServerAddress: '', | ||||
|     WeChatServerToken: '', | ||||
|     WeChatAccountQRCodeImageURL: '', | ||||
|     GoogleOAuthEnabled: '', | ||||
|     GoogleClientId: '', | ||||
|     GoogleClientSecret: '', | ||||
|     TurnstileCheckEnabled: '', | ||||
|     TurnstileSiteKey: '', | ||||
|     TurnstileSecretKey: '', | ||||
|     RegisterEnabled: '', | ||||
|     EmailDomainRestrictionEnabled: '', | ||||
|     EmailDomainWhitelist: '' | ||||
|   }); | ||||
|   const [originInputs, setOriginInputs] = useState({}); | ||||
|   let [loading, setLoading] = useState(false); | ||||
|   const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); | ||||
|   const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); | ||||
|  | ||||
|   const getOptions = async () => { | ||||
|     const res = await API.get('/api/option/'); | ||||
| @@ -44,8 +42,15 @@ const SystemSetting = () => { | ||||
|       data.forEach((item) => { | ||||
|         newInputs[item.key] = item.value; | ||||
|       }); | ||||
|       setInputs(newInputs); | ||||
|       setInputs({ | ||||
|         ...newInputs, | ||||
|         EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',') | ||||
|       }); | ||||
|       setOriginInputs(newInputs); | ||||
|  | ||||
|       setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => { | ||||
|         return { key: item, text: item, value: item }; | ||||
|       })); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -62,10 +67,9 @@ const SystemSetting = () => { | ||||
|       case 'PasswordRegisterEnabled': | ||||
|       case 'EmailVerificationEnabled': | ||||
|       case 'GitHubOAuthEnabled': | ||||
|       case 'DiscordOAuthEnabled': | ||||
|       case 'WeChatAuthEnabled': | ||||
|       case 'GoogleOAuthEnabled': | ||||
|       case 'TurnstileCheckEnabled': | ||||
|       case 'EmailDomainRestrictionEnabled': | ||||
|       case 'RegisterEnabled': | ||||
|         value = inputs[key] === 'true' ? 'false' : 'true'; | ||||
|         break; | ||||
| @@ -78,7 +82,12 @@ const SystemSetting = () => { | ||||
|     }); | ||||
|     const { success, message } = res.data; | ||||
|     if (success) { | ||||
|       setInputs((inputs) => ({ ...inputs, [key]: value })); | ||||
|       if (key === 'EmailDomainWhitelist') { | ||||
|         value = value.split(','); | ||||
|       } | ||||
|       setInputs((inputs) => ({ | ||||
|         ...inputs, [key]: value | ||||
|       })); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -92,15 +101,12 @@ const SystemSetting = () => { | ||||
|       name === 'ServerAddress' || | ||||
|       name === 'GitHubClientId' || | ||||
|       name === 'GitHubClientSecret' || | ||||
|       name === 'DiscordClientId' || | ||||
|       name === 'DiscordClientSecret' || | ||||
|       name === 'WeChatServerAddress' || | ||||
|       name === 'WeChatServerToken' || | ||||
|       name === 'WeChatAccountQRCodeImageURL' || | ||||
|       name === 'GoogleClientId' || | ||||
|       name === 'GoogleClientSecret' || | ||||
|       name === 'TurnstileSiteKey' || | ||||
|       name === 'TurnstileSecretKey' | ||||
|       name === 'TurnstileSecretKey' || | ||||
|       name === 'EmailDomainWhitelist' | ||||
|     ) { | ||||
|       setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|     } else { | ||||
| @@ -137,6 +143,16 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|  | ||||
|   const submitEmailDomainWhitelist = async () => { | ||||
|     if ( | ||||
|       originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') && | ||||
|       inputs.SMTPToken !== '' | ||||
|     ) { | ||||
|       await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(',')); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitWeChat = async () => { | ||||
|     if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { | ||||
|       await updateOption( | ||||
| @@ -161,18 +177,6 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitGoogleOAuth = async () => { | ||||
|     if (originInputs['GoogleClientId'] !== inputs.GoogleClientId) { | ||||
|       await updateOption('GoogleClientId', inputs.GoogleClientId); | ||||
|     } | ||||
|     if ( | ||||
|       originInputs['GoogleClientSecret'] !== inputs.GoogleClientSecret && | ||||
|       inputs.GoogleClientSecret !== '' | ||||
|     ) { | ||||
|       await updateOption('GoogleClientSecret', inputs.GoogleClientSecret); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitGitHubOAuth = async () => { | ||||
|     if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { | ||||
|       await updateOption('GitHubClientId', inputs.GitHubClientId); | ||||
| @@ -185,18 +189,6 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitDiscordOAuth = async () => { | ||||
|     if (originInputs['DiscordClientId'] !== inputs.DiscordClientId) { | ||||
|       await updateOption('DiscordClientId', inputs.DiscordClientId); | ||||
|     } | ||||
|     if ( | ||||
|       originInputs['DiscordClientSecret'] !== inputs.DiscordClientSecret && | ||||
|       inputs.DiscordClientSecret !== '' | ||||
|     ) { | ||||
|       await updateOption('DiscordClientSecret', inputs.DiscordClientSecret); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitTurnstile = async () => { | ||||
|     if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) { | ||||
|       await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey); | ||||
| @@ -209,6 +201,22 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitNewRestrictedDomain = () => { | ||||
|     const localDomainList = inputs.EmailDomainWhitelist; | ||||
|     if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) { | ||||
|       setRestrictedDomainInput(''); | ||||
|       setInputs({ | ||||
|         ...inputs, | ||||
|         EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], | ||||
|       }); | ||||
|       setEmailDomainWhitelist([...EmailDomainWhitelist, { | ||||
|         key: restrictedDomainInput, | ||||
|         text: restrictedDomainInput, | ||||
|         value: restrictedDomainInput, | ||||
|       }]); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return ( | ||||
|     <Grid columns={1}> | ||||
|       <Grid.Column> | ||||
| @@ -247,24 +255,12 @@ const SystemSetting = () => { | ||||
|               name='EmailVerificationEnabled' | ||||
|               onChange={handleInputChange} | ||||
|             /> | ||||
|             <Form.Checkbox | ||||
|               checked={inputs.DiscordOAuthEnabled === 'true'} | ||||
|               label='允许通过 Discord 账户登录和注册' | ||||
|               name='DiscordOAuthEnabled' | ||||
|               onChange={handleInputChange} | ||||
|             /> | ||||
|             <Form.Checkbox | ||||
|               checked={inputs.GitHubOAuthEnabled === 'true'} | ||||
|               label='允许通过 GitHub 账户登录 & 注册' | ||||
|               name='GitHubOAuthEnabled' | ||||
|               onChange={handleInputChange} | ||||
|             /> | ||||
|             <Form.Checkbox | ||||
|               checked={inputs.GoogleOAuthEnabled === 'true'} | ||||
|               label='允许通过 Google 账户登录和注册' | ||||
|               name='GoogleOAuthEnabled' | ||||
|               onChange={handleInputChange} | ||||
|             /> | ||||
|             <Form.Checkbox | ||||
|               checked={inputs.WeChatAuthEnabled === 'true'} | ||||
|               label='允许通过微信登录 & 注册' | ||||
| @@ -287,6 +283,54 @@ const SystemSetting = () => { | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置邮箱域名白名单 | ||||
|             <Header.Subheader>用以防止恶意用户利用临时邮箱批量注册</Header.Subheader> | ||||
|           </Header> | ||||
|           <Form.Group widths={3}> | ||||
|             <Form.Checkbox | ||||
|               label='启用邮箱域名白名单' | ||||
|               name='EmailDomainRestrictionEnabled' | ||||
|               onChange={handleInputChange} | ||||
|               checked={inputs.EmailDomainRestrictionEnabled === 'true'} | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Group widths={2}> | ||||
|             <Form.Dropdown | ||||
|               label='允许的邮箱域名' | ||||
|               placeholder='允许的邮箱域名' | ||||
|               name='EmailDomainWhitelist' | ||||
|               required | ||||
|               fluid | ||||
|               multiple | ||||
|               selection | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.EmailDomainWhitelist} | ||||
|               autoComplete='new-password' | ||||
|               options={EmailDomainWhitelist} | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='添加新的允许的邮箱域名' | ||||
|               action={ | ||||
|                 <Button type='button' onClick={() => { | ||||
|                   submitNewRestrictedDomain(); | ||||
|                 }}>填入</Button> | ||||
|               } | ||||
|               onKeyDown={(e) => { | ||||
|                 if (e.key === 'Enter') { | ||||
|                   submitNewRestrictedDomain(); | ||||
|                 } | ||||
|               }} | ||||
|               autoComplete='new-password' | ||||
|               placeholder='输入新的允许的邮箱域名' | ||||
|               value={restrictedDomainInput} | ||||
|               onChange={(e, { value }) => { | ||||
|                 setRestrictedDomainInput(value); | ||||
|               }} | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={submitEmailDomainWhitelist}>保存邮箱域名白名单设置</Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 SMTP | ||||
|             <Header.Subheader>用以支持系统的邮件发送</Header.Subheader> | ||||
| @@ -332,7 +376,7 @@ const SystemSetting = () => { | ||||
|               onChange={handleInputChange} | ||||
|               type='password' | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.SMTPToken} | ||||
|               checked={inputs.RegisterEnabled === 'true'} | ||||
|               placeholder='敏感信息不会发送到前端显示' | ||||
|             /> | ||||
|           </Form.Group> | ||||
| @@ -420,82 +464,6 @@ const SystemSetting = () => { | ||||
|             保存 WeChat Server 设置 | ||||
|           </Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 Discord OAuth 应用程序 | ||||
|             <Header.Subheader> | ||||
|               用以支持通过 Discord 进行登录注册, | ||||
|               <a href='https://discord.com/developers/applications' target='_blank'> | ||||
|                 点击此处 | ||||
|               </a> | ||||
|               管理你的 Discord OAuth App | ||||
|             </Header.Subheader> | ||||
|           </Header> | ||||
|           <Message> | ||||
|             Homepage URL 填 <code>{inputs.ServerAddress}</code> | ||||
|             ,Authorization callback URL 填{' '} | ||||
|             <code>{`${inputs.ServerAddress}/oauth/discord`}</code> | ||||
|           </Message> | ||||
|           <Form.Group widths={3}> | ||||
|             <Form.Input | ||||
|               label='Discord 客户 ID' | ||||
|               name='DiscordClientId' | ||||
|               onChange={handleInputChange} | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.DiscordClientId} | ||||
|               placeholder='输入您注册的 Discord OAuth APP 的 ID' | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='Discord 客户秘密' | ||||
|               name='DiscordClientSecret' | ||||
|               onChange={handleInputChange} | ||||
|               type='password' | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.DiscordClientSecret} | ||||
|               placeholder='敏感信息不会发送到前端显示' | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={submitDiscordOAuth}> | ||||
|             保存 Discord OAuth 设置 | ||||
|           </Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 Google OAuth 应用程序 | ||||
|             <Header.Subheader> | ||||
|               用以支持通过 Google 进行登录注册, | ||||
|               <a href='https://console.cloud.google.com/' target='_blank'> | ||||
|                 点击此处 | ||||
|               </a> | ||||
|               管理你的 Google OAuth App | ||||
|             </Header.Subheader> | ||||
|           </Header> | ||||
|           <Message> | ||||
|             Homepage URL 填 <code>{inputs.ServerAddress}</code> | ||||
|             ,Authorization callback URL 填{' '} | ||||
|             <code>{`${inputs.ServerAddress}/oauth/google`}</code> | ||||
|           </Message> | ||||
|           <Form.Group widths={3}> | ||||
|             <Form.Input | ||||
|               label='Google 客户 ID' | ||||
|               name='GoogleClientId' | ||||
|               onChange={handleInputChange} | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.GoogleClientId} | ||||
|               placeholder='输入您注册的 Google OAuth APP 的 ID' | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='Google 客户秘密' | ||||
|               name='GoogleClientSecret' | ||||
|               onChange={handleInputChange} | ||||
|               type='password' | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.GoogleClientSecret} | ||||
|               placeholder='敏感信息不会发送到前端显示' | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={submitGoogleOAuth}> | ||||
|             保存 Google OAuth 设置 | ||||
|           </Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 Turnstile | ||||
|             <Header.Subheader> | ||||
|   | ||||
| @@ -1,11 +1,17 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Label, Modal, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Button, Dropdown, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||
| import { Link } from 'react-router-dom'; | ||||
| import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers'; | ||||
|  | ||||
| import { ITEMS_PER_PAGE } from '../constants'; | ||||
| import { renderQuota } from '../helpers/render'; | ||||
|  | ||||
| const COPY_OPTIONS = [ | ||||
|   { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, | ||||
|   { key: 'ama', text: 'AMA 问天', value: 'ama' }, | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||
| ]; | ||||
|  | ||||
| function renderTimestamp(timestamp) { | ||||
|   return ( | ||||
|     <> | ||||
| @@ -68,7 +74,40 @@ const TokensTable = () => { | ||||
|   const refresh = async () => { | ||||
|     setLoading(true); | ||||
|     await loadTokens(activePage - 1); | ||||
|   } | ||||
|   }; | ||||
|  | ||||
|   const onCopy = async (type, key) => { | ||||
|     let status = localStorage.getItem('status'); | ||||
|     let serverAddress = ''; | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       serverAddress = status.server_address; | ||||
|     } | ||||
|     if (serverAddress === '') { | ||||
|       serverAddress = window.location.origin; | ||||
|     } | ||||
|     let encodedServerAddress = encodeURIComponent(serverAddress); | ||||
|     let url; | ||||
|     switch (type) { | ||||
|       case 'ama': | ||||
|         url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; | ||||
|         break; | ||||
|       case 'opencat': | ||||
|         url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; | ||||
|         break; | ||||
|       case 'next': | ||||
|         url = `https://chat.oneapi.pro/#/?settings=%7B%22key%22:%22sk-${key}%22,%22url%22:%22${serverAddress}%22%7D`; | ||||
|         break; | ||||
|       default: | ||||
|         url = `sk-${key}`; | ||||
|     } | ||||
|     if (await copy(url)) { | ||||
|       showSuccess('已复制到剪贴板!'); | ||||
|     } else { | ||||
|       showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); | ||||
|       setSearchKeyword(url); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     loadTokens(0) | ||||
| @@ -235,21 +274,28 @@ const TokensTable = () => { | ||||
|                   <Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell> | ||||
|                   <Table.Cell> | ||||
|                     <div> | ||||
|                       <Button | ||||
|                         size={'small'} | ||||
|                         positive | ||||
|                         onClick={async () => { | ||||
|                           let key = "sk-" + token.key; | ||||
|                           if (await copy(key)) { | ||||
|                             showSuccess('已复制到剪贴板!'); | ||||
|                           } else { | ||||
|                             showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); | ||||
|                             setSearchKeyword(key); | ||||
|                       <Button.Group color='green' size={'small'}> | ||||
|                         <Button | ||||
|                           size={'small'} | ||||
|                           positive | ||||
|                           onClick={async () => { | ||||
|                             await onCopy('', token.key); | ||||
|                           } | ||||
|                         }} | ||||
|                       > | ||||
|                         复制 | ||||
|                       </Button> | ||||
|                           } | ||||
|                         > | ||||
|                           复制 | ||||
|                         </Button> | ||||
|                         <Dropdown | ||||
|                           className='button icon' | ||||
|                           floating | ||||
|                           options={COPY_OPTIONS} | ||||
|                           onChange={async (e, { value } = {}) => { | ||||
|                             await onCopy(value, token.key); | ||||
|                           }} | ||||
|                           trigger={<></>} | ||||
|                         /> | ||||
|                       </Button.Group> | ||||
|                       {' '} | ||||
|                       <Popup | ||||
|                         trigger={ | ||||
|                           <Button size='small' negative> | ||||
|   | ||||
| @@ -227,7 +227,7 @@ const UsersTable = () => { | ||||
|                       content={user.email ? user.email : '未绑定邮箱地址'} | ||||
|                       key={user.username} | ||||
|                       header={user.display_name ? user.display_name : user.username} | ||||
|                       trigger={<span>{renderText(user.username, 10)}</span>} | ||||
|                       trigger={<span>{renderText(user.username, 15)}</span>} | ||||
|                       hoverable | ||||
|                     /> | ||||
|                   </Table.Cell> | ||||
|   | ||||
| @@ -4,6 +4,8 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||
|   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||
|   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||
|   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| export const toastConstants = { | ||||
|   SUCCESS_TIMEOUT: 500, | ||||
|   SUCCESS_TIMEOUT: 1500, | ||||
|   INFO_TIMEOUT: 3000, | ||||
|   ERROR_TIMEOUT: 5000, | ||||
|   WARNING_TIMEOUT: 10000, | ||||
|   | ||||
| @@ -22,8 +22,6 @@ const EditChannel = () => { | ||||
|     base_url: '', | ||||
|     other: '', | ||||
|     model_mapping: '', | ||||
|     allow_streaming: 1, | ||||
|     allow_non_streaming: 1, | ||||
|     models: [], | ||||
|     groups: ['default'] | ||||
|   }; | ||||
| @@ -37,6 +35,30 @@ const EditChannel = () => { | ||||
|   const [customModel, setCustomModel] = useState(''); | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|     if (name === 'type' && inputs.models.length === 0) { | ||||
|       let localModels = []; | ||||
|       switch (value) { | ||||
|         case 14: | ||||
|           localModels = ['claude-instant-1', 'claude-2']; | ||||
|           break; | ||||
|         case 11: | ||||
|           localModels = ['PaLM-2']; | ||||
|           break; | ||||
|         case 15: | ||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||
|           break; | ||||
|         case 17: | ||||
|           localModels = ['qwen-v1', 'qwen-plus-v1']; | ||||
|           break; | ||||
|         case 16: | ||||
|           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||
|           break; | ||||
|         case 18: | ||||
|           localModels = ['SparkDesk']; | ||||
|           break; | ||||
|       } | ||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const loadChannel = async () => { | ||||
| @@ -96,9 +118,6 @@ const EditChannel = () => { | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let localModelOptions = [...originModelOptions]; | ||||
|     if (!Array.isArray(inputs.models)) { | ||||
|       inputs.models = inputs.models.split(','); | ||||
|     } | ||||
|     inputs.models.forEach((model) => { | ||||
|       if (!localModelOptions.find((option) => option.key === model)) { | ||||
|         localModelOptions.push({ | ||||
| @@ -132,17 +151,15 @@ const EditChannel = () => { | ||||
|       showInfo('模型映射必须是合法的 JSON 格式!'); | ||||
|       return; | ||||
|     } | ||||
|     // allow streaming and allow non streaming cannot be both false | ||||
|     if (inputs.allow_streaming === 2 && inputs.allow_non_streaming === 2) { | ||||
|       showInfo('流式请求和非流式请求不能同时禁用!'); | ||||
|       return; | ||||
|     } | ||||
|     let localInputs = inputs; | ||||
|     if (localInputs.base_url.endsWith('/')) { | ||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||
|     } | ||||
|     if (localInputs.type === 3 && localInputs.other === '') { | ||||
|       localInputs.other = '2023-03-15-preview'; | ||||
|       localInputs.other = '2023-06-01-preview'; | ||||
|     } | ||||
|     if (localInputs.model_mapping === '') { | ||||
|       localInputs.model_mapping = '{}'; | ||||
|     } | ||||
|     let res; | ||||
|     localInputs.models = localInputs.models.join(','); | ||||
| @@ -186,7 +203,7 @@ const EditChannel = () => { | ||||
|                 <Message> | ||||
|                   注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的 model | ||||
|                   参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank' | ||||
|                     href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。 | ||||
|                                                                     href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。 | ||||
|                 </Message> | ||||
|                 <Form.Field> | ||||
|                   <Form.Input | ||||
| @@ -202,7 +219,7 @@ const EditChannel = () => { | ||||
|                   <Form.Input | ||||
|                     label='默认 API 版本' | ||||
|                     name='other' | ||||
|                     placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||
|                     placeholder={'请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖'} | ||||
|                     onChange={handleInputChange} | ||||
|                     value={inputs.other} | ||||
|                     autoComplete='new-password' | ||||
| @@ -281,7 +298,7 @@ const EditChannel = () => { | ||||
|             <Input | ||||
|               action={ | ||||
|                 <Button type={'button'} onClick={() => { | ||||
|                   if (customModel.trim() === "") return; | ||||
|                   if (customModel.trim() === '') return; | ||||
|                   if (inputs.models.includes(customModel)) return; | ||||
|                   let localModels = [...inputs.models]; | ||||
|                   localModels.push(customModel); | ||||
| @@ -289,7 +306,7 @@ const EditChannel = () => { | ||||
|                   localModelOptions.push({ | ||||
|                     key: customModel, | ||||
|                     text: customModel, | ||||
|                     value: customModel, | ||||
|                     value: customModel | ||||
|                   }); | ||||
|                   setModelOptions(modelOptions => { | ||||
|                     return [...modelOptions, ...localModelOptions]; | ||||
| @@ -307,7 +324,7 @@ const EditChannel = () => { | ||||
|           </div> | ||||
|           <Form.Field> | ||||
|             <Form.TextArea | ||||
|               label='模型映射' | ||||
|               label='模型重定向' | ||||
|               placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`} | ||||
|               name='model_mapping' | ||||
|               onChange={handleInputChange} | ||||
| @@ -316,26 +333,6 @@ const EditChannel = () => { | ||||
|               autoComplete='new-password' | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Checkbox | ||||
|               checked={inputs.allow_streaming === 1} | ||||
|               label='允许流式请求' | ||||
|               name='allow_streaming' | ||||
|               onChange={() => { | ||||
|                 setInputs((inputs) => ({ ...inputs, allow_streaming: inputs.allow_streaming === 1 ? 2 : 1 })); | ||||
|               }} | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Checkbox | ||||
|               checked={inputs.allow_non_streaming === 1} | ||||
|               label='允许非流式请求' | ||||
|               name='allow_non_streaming' | ||||
|               onChange={() => { | ||||
|                 setInputs((inputs) => ({ ...inputs, allow_non_streaming: inputs.allow_non_streaming === 1 ? 2 : 1 })); | ||||
|               }} | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           { | ||||
|             batch ? <Form.Field> | ||||
|               <Form.TextArea | ||||
| @@ -353,7 +350,7 @@ const EditChannel = () => { | ||||
|                 label='密钥' | ||||
|                 name='key' | ||||
|                 required | ||||
|                 placeholder={inputs.type === 15 ? "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次" : '请输入渠道对应的鉴权密钥'} | ||||
|                 placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} | ||||
|                 onChange={handleInputChange} | ||||
|                 value={inputs.key} | ||||
|                 autoComplete='new-password' | ||||
| @@ -384,7 +381,7 @@ const EditChannel = () => { | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           <Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button> | ||||
|           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button> | ||||
|         </Form> | ||||
|       </Segment> | ||||
|     </> | ||||
|   | ||||
| @@ -97,24 +97,12 @@ const Home = () => { | ||||
|                           ? '已启用' | ||||
|                           : '未启用'} | ||||
|                       </p> | ||||
|                       <p> | ||||
|                         Discord 身份验证: | ||||
|                         {statusState?.status?.discord_oauth === true | ||||
|                           ? '已启用' | ||||
|                           : '未启用'} | ||||
|                       </p> | ||||
|                       <p> | ||||
|                         微信身份验证: | ||||
|                         {statusState?.status?.wechat_login === true | ||||
|                           ? '已启用' | ||||
|                           : '未启用'} | ||||
|                       </p> | ||||
|                       <p> | ||||
|                         Google 身份验证: | ||||
|                         {statusState?.status?.google_oauth === true | ||||
|                           ? '已启用' | ||||
|                           : '未启用'} | ||||
|                       </p> | ||||
|                       <p> | ||||
|                         Turnstile 用户校验: | ||||
|                         {statusState?.status?.turnstile_check === true | ||||
|   | ||||
| @@ -83,7 +83,7 @@ const EditToken = () => { | ||||
|       if (isEdit) { | ||||
|         showSuccess('令牌更新成功!'); | ||||
|       } else { | ||||
|         showSuccess('令牌创建成功!'); | ||||
|         showSuccess('令牌创建成功,请在列表页面点击复制获取令牌!'); | ||||
|         setInputs(originInputs); | ||||
|       } | ||||
|     } else { | ||||
|   | ||||
| @@ -13,15 +13,13 @@ const EditUser = () => { | ||||
|     display_name: '', | ||||
|     password: '', | ||||
|     github_id: '', | ||||
|     discord_id: '', | ||||
|     wechat_id: '', | ||||
|     google_id: '', | ||||
|     email: '', | ||||
|     quota: 0, | ||||
|     group: 'default' | ||||
|   }); | ||||
|   const [groupOptions, setGroupOptions] = useState([]); | ||||
|   const { username, display_name, password, github_id, wechat_id, email, quota, google_id, discord_id } = | ||||
|   const { username, display_name, password, github_id, wechat_id, email, quota, group } = | ||||
|     inputs; | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
| @@ -168,26 +166,6 @@ const EditUser = () => { | ||||
|               readOnly | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='已绑定的 Discord 账户' | ||||
|               name='discord_id' | ||||
|               value={discord_id} | ||||
|               autoComplete='new-password' | ||||
|               placeholder='此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改' | ||||
|               readOnly | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='已绑定的 Google 账户' | ||||
|               name='google_id' | ||||
|               value={google_id} | ||||
|               autoComplete='new-password' | ||||
|               placeholder='此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改' | ||||
|               readOnly | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='已绑定的邮箱账户' | ||||
|   | ||||
		Reference in New Issue
	
	Block a user