mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-26 11:23:43 +08:00 
			
		
		
		
	Compare commits
	
		
			30 Commits
		
	
	
		
			v0.6.4
			...
			v0.6.5-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 840ef80d94 | ||
|  | 9a2662af0d | ||
|  | 77f9e75654 | ||
|  | 5b41f57423 | ||
|  | 0bb7db0b44 | ||
|  | 4d61b9937b | ||
|  | 68605800af | ||
|  | c49778c254 | ||
|  | f02c7138ea | ||
|  | ca3228855a | ||
|  | f8cc63f00b | ||
|  | 0a37aa4cbd | ||
|  | 054b00b725 | ||
|  | 76569bb0b6 | ||
|  | 1994256bac | ||
|  | 1f80b0a39f | ||
|  | f73f2e51df | ||
|  | 6f036bd0c9 | ||
|  | fb90747c23 | ||
|  | ed70881a58 | ||
|  | 8b9fa3d6e4 | ||
|  | 8b9813d63b | ||
|  | dc7aaf2de5 | ||
|  | 065da8ef8c | ||
|  | e3cfb1fa52 | ||
|  | f89ae5ad58 | ||
|  | 06a3fc5421 | ||
|  | a9c464ec5a | ||
|  | 3f3c13c98c | ||
|  | 2ba28c72cb | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -8,3 +8,4 @@ build | ||||
| logs | ||||
| data | ||||
| /web/node_modules | ||||
| cmd.md | ||||
| @@ -81,11 +81,12 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [Groq](https://wow.groq.com/) | ||||
|    + [x] [Ollama](https://github.com/ollama/ollama) | ||||
|    + [x] [零一万物](https://platform.lingyiwanwu.com/) | ||||
|    + [x] [阶跃星辰](https://platform.stepfun.com/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| 5. 支持**多机部署**,[详见此处](#多机部署)。 | ||||
| 6. 支持**令牌管理**,设置令牌的过期时间和额度。 | ||||
| 6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 | ||||
| 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 | ||||
| 8. 支持**渠道管理**,批量创建渠道。 | ||||
| 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 | ||||
| @@ -101,10 +102,11 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
| 19. 支持丰富的**自定义**设置, | ||||
|     1. 支持自定义系统名称,logo 以及页脚。 | ||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||
| 20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 | ||||
| 20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。 | ||||
| 21. 支持 Cloudflare Turnstile 用户校验。 | ||||
| 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||
|     + 支持使用飞书进行授权登录。 | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||
|   | ||||
| @@ -66,6 +66,9 @@ var SMTPToken = "" | ||||
| var GitHubClientId = "" | ||||
| var GitHubClientSecret = "" | ||||
|  | ||||
| var LarkClientId = "" | ||||
| var LarkClientSecret = "" | ||||
|  | ||||
| var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
|   | ||||
| @@ -71,6 +71,7 @@ const ( | ||||
| 	ChannelTypeGroq | ||||
| 	ChannelTypeOllama | ||||
| 	ChannelTypeLingYiWanWu | ||||
| 	ChannelTypeStepFun | ||||
|  | ||||
| 	ChannelTypeDummy | ||||
| ) | ||||
| @@ -108,6 +109,7 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://api.groq.com/openai",               // 29 | ||||
| 	"http://localhost:11434",                    // 30 | ||||
| 	"https://api.lingyiwanwu.com",               // 31 | ||||
| 	"https://api.stepfun.com",                   // 32 | ||||
| } | ||||
|  | ||||
| const ( | ||||
|   | ||||
							
								
								
									
										6
									
								
								common/conv/any.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								common/conv/any.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package conv | ||||
|  | ||||
| func AsString(v any) string { | ||||
| 	str, _ := v.(string) | ||||
| 	return str | ||||
| } | ||||
| @@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{ | ||||
| 	"claude-3-sonnet-20240229": 3.0 / 1000 * USD, | ||||
| 	"claude-3-opus-20240229":   15.0 / 1000 * USD, | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | ||||
| 	"ERNIE-Bot":       0.8572,     // ¥0.012 / 1k tokens | ||||
| 	"ERNIE-Bot-turbo": 0.5715,     // ¥0.008 / 1k tokens | ||||
| 	"ERNIE-Bot-4":     0.12 * RMB, // ¥0.12 / 1k tokens | ||||
| 	"ERNIE-Bot-8k":    0.024 * RMB, | ||||
| 	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens | ||||
| 	"bge-large-zh":    0.002 * RMB, | ||||
| 	"bge-large-en":    0.002 * RMB, | ||||
| 	"bge-large-8k":    0.002 * RMB, | ||||
| 	"ERNIE-4.0-8K":       0.120 * RMB, | ||||
| 	"ERNIE-Bot-8K-0922":  0.024 * RMB, | ||||
| 	"ERNIE-3.5-8K":       0.012 * RMB, | ||||
| 	"ERNIE-Lite-8K-0922": 0.008 * RMB, | ||||
| 	"ERNIE-Speed-8K":     0.004 * RMB, | ||||
| 	"ERNIE-3.5-4K-0205":  0.012 * RMB, | ||||
| 	"ERNIE-3.5-8K-0205":  0.024 * RMB, | ||||
| 	"ERNIE-3.5-8K-1222":  0.012 * RMB, | ||||
| 	"ERNIE-Lite-8K":      0.003 * RMB, | ||||
| 	"ERNIE-Speed-128K":   0.004 * RMB, | ||||
| 	"ERNIE-Tiny-8K":      0.001 * RMB, | ||||
| 	"BLOOMZ-7B":          0.004 * RMB, | ||||
| 	"Embedding-V1":       0.002 * RMB, | ||||
| 	"bge-large-zh":       0.002 * RMB, | ||||
| 	"bge-large-en":       0.002 * RMB, | ||||
| 	"tao-8k":             0.002 * RMB, | ||||
| 	// https://ai.google.dev/pricing | ||||
| 	"PaLM-2":                    1, | ||||
| 	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| @@ -91,6 +99,7 @@ var ModelRatio = map[string]float64{ | ||||
| 	"glm-4":                     0.1 * RMB, | ||||
| 	"glm-4v":                    0.1 * RMB, | ||||
| 	"glm-3-turbo":               0.005 * RMB, | ||||
| 	"embedding-2":               0.0005 * RMB, | ||||
| 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens | ||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||
|   | ||||
							
								
								
									
										25
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| package network | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"net" | ||||
| ) | ||||
|  | ||||
| func IsValidSubnet(subnet string) error { | ||||
| 	_, _, err := net.ParseCIDR(subnet) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to parse subnet: %w", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool { | ||||
| 	_, ipNet, err := net.ParseCIDR(subnet) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) | ||||
| 		return false | ||||
| 	} | ||||
| 	return ipNet.Contains(net.ParseIP(ip)) | ||||
| } | ||||
							
								
								
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package network | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	. "github.com/smartystreets/goconvey/convey" | ||||
| ) | ||||
|  | ||||
| func TestIsIpInSubnet(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	ip1 := "192.168.0.5" | ||||
| 	ip2 := "125.216.250.89" | ||||
| 	subnet := "192.168.0.0/24" | ||||
| 	Convey("TestIsIpInSubnet", t, func() { | ||||
| 		So(IsIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) | ||||
| 		So(IsIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) | ||||
| 	}) | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| @@ -11,6 +11,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -159,7 +160,7 @@ func GitHubOAuth(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
| 
 | ||||
| func GitHubBind(c *gin.Context) { | ||||
							
								
								
									
										201
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,201 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type LarkOAuthResponse struct { | ||||
| 	AccessToken string `json:"access_token"` | ||||
| } | ||||
|  | ||||
| type LarkUser struct { | ||||
| 	Name   string `json:"name"` | ||||
| 	OpenID string `json:"open_id"` | ||||
| } | ||||
|  | ||||
| func getLarkUserInfoByCode(code string) (*LarkUser, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("无效的参数") | ||||
| 	} | ||||
| 	values := map[string]string{ | ||||
| 		"client_id":     config.LarkClientId, | ||||
| 		"client_secret": config.LarkClientSecret, | ||||
| 		"code":          code, | ||||
| 		"grant_type":    "authorization_code", | ||||
| 		"redirect_uri":  fmt.Sprintf("%s/oauth/lark", config.ServerAddress), | ||||
| 	} | ||||
| 	jsonData, err := json.Marshal(values) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	req.Header.Set("Accept", "application/json") | ||||
| 	client := http.Client{ | ||||
| 		Timeout: 5 * time.Second, | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至飞书服务器,请稍后重试!") | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	var oAuthResponse LarkOAuthResponse | ||||
| 	err = json.NewDecoder(res.Body).Decode(&oAuthResponse) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) | ||||
| 	res2, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.SysLog(err.Error()) | ||||
| 		return nil, errors.New("无法连接至飞书服务器,请稍后重试!") | ||||
| 	} | ||||
| 	var larkUser LarkUser | ||||
| 	err = json.NewDecoder(res2.Body).Decode(&larkUser) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &larkUser, nil | ||||
| } | ||||
|  | ||||
| func LarkOAuth(c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	state := c.Query("state") | ||||
| 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||
| 		c.JSON(http.StatusForbidden, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "state is empty or not same", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	username := session.Get("username") | ||||
| 	if username != nil { | ||||
| 		LarkBind(c) | ||||
| 		return | ||||
| 	} | ||||
| 	code := c.Query("code") | ||||
| 	larkUser, err := getLarkUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		LarkId: larkUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsLarkIdAlreadyTaken(user.LarkId) { | ||||
| 		err := user.FillUserByLarkId() | ||||
| 		if err != nil { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": err.Error(), | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if config.RegisterEnabled { | ||||
| 			user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) | ||||
| 			if larkUser.Name != "" { | ||||
| 				user.DisplayName = larkUser.Name | ||||
| 			} else { | ||||
| 				user.DisplayName = "Lark User" | ||||
| 			} | ||||
| 			user.Role = common.RoleCommonUser | ||||
| 			user.Status = common.UserStatusEnabled | ||||
|  | ||||
| 			if err := user.Insert(0); err != nil { | ||||
| 				c.JSON(http.StatusOK, gin.H{ | ||||
| 					"success": false, | ||||
| 					"message": err.Error(), | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			c.JSON(http.StatusOK, gin.H{ | ||||
| 				"success": false, | ||||
| 				"message": "管理员关闭了新用户注册", | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if user.Status != common.UserStatusEnabled { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"message": "用户已被封禁", | ||||
| 			"success": false, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| func LarkBind(c *gin.Context) { | ||||
| 	code := c.Query("code") | ||||
| 	larkUser, err := getLarkUserInfoByCode(code) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user := model.User{ | ||||
| 		LarkId: larkUser.OpenID, | ||||
| 	} | ||||
| 	if model.IsLarkIdAlreadyTaken(user.LarkId) { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "该飞书账户已被绑定", | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	session := sessions.Default(c) | ||||
| 	id := session.Get("id") | ||||
| 	// id := c.GetInt("id")  // critical bug! | ||||
| 	user.Id = id.(int) | ||||
| 	err = user.FillUserById() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	user.LarkId = larkUser.OpenID | ||||
| 	err = user.Update(false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "bind", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package controller | ||||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -109,7 +110,7 @@ func WeChatAuth(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| 	controller.SetupLogin(&user, c) | ||||
| } | ||||
| 
 | ||||
| func WeChatBind(c *gin.Context) { | ||||
| @@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) { | ||||
| 			"email_verification":  config.EmailVerificationEnabled, | ||||
| 			"github_oauth":        config.GitHubOAuthEnabled, | ||||
| 			"github_client_id":    config.GitHubClientId, | ||||
| 			"lark_client_id":      config.LarkClientId, | ||||
| 			"system_name":         config.SystemName, | ||||
| 			"logo":                config.Logo, | ||||
| 			"footer_html":         config.Footer, | ||||
|   | ||||
| @@ -4,12 +4,14 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/helper" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://platform.openai.com/docs/api-reference/models/list | ||||
| @@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func ListModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var availableModels []string | ||||
| 	if c.GetString("available_models") != "" { | ||||
| 		availableModels = strings.Split(c.GetString("available_models"), ",") | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) | ||||
| 	} | ||||
| 	modelSet := make(map[string]bool) | ||||
| 	for _, availableModel := range availableModels { | ||||
| 		modelSet[availableModel] = true | ||||
| 	} | ||||
| 	availableOpenAIModels := make([]OpenAIModels, 0) | ||||
| 	for _, model := range openAIModels { | ||||
| 		if _, ok := modelSet[model.Id]; ok { | ||||
| 			modelSet[model.Id] = false | ||||
| 			availableOpenAIModels = append(availableOpenAIModels, model) | ||||
| 		} | ||||
| 	} | ||||
| 	for modelName, ok := range modelSet { | ||||
| 		if ok { | ||||
| 			availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ | ||||
| 				Id:      modelName, | ||||
| 				Object:  "model", | ||||
| 				Created: 1626777600, | ||||
| 				OwnedBy: "custom", | ||||
| 				Root:    modelName, | ||||
| 				Parent:  nil, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	c.JSON(200, gin.H{ | ||||
| 		"object": "list", | ||||
| 		"data":   openAIModels, | ||||
| 		"data":   availableOpenAIModels, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| @@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetUserAvailableModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	id := c.GetInt("id") | ||||
| 	userGroup, err := model.CacheGetUserGroup(id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	models, err := model.CacheGetGroupModels(ctx, userGroup) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    models, | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -1,10 +1,12 @@ | ||||
| package controller | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/network" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -104,6 +106,19 @@ func GetTokenStatus(c *gin.Context) { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func validateToken(c *gin.Context, token model.Token) error { | ||||
| 	if len(token.Name) > 30 { | ||||
| 		return fmt.Errorf("令牌名称过长") | ||||
| 	} | ||||
| 	if token.Subnet != nil && *token.Subnet != "" { | ||||
| 		err := network.IsValidSubnet(*token.Subnet) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("无效的网段:%s", err.Error()) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func AddToken(c *gin.Context) { | ||||
| 	token := model.Token{} | ||||
| 	err := c.ShouldBindJSON(&token) | ||||
| @@ -114,13 +129,15 @@ func AddToken(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(token.Name) > 30 { | ||||
| 	err = validateToken(c, token) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "令牌名称过长", | ||||
| 			"message": fmt.Sprintf("参数错误:%s", err.Error()), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	cleanToken := model.Token{ | ||||
| 		UserId:         c.GetInt("id"), | ||||
| 		Name:           token.Name, | ||||
| @@ -130,6 +147,8 @@ func AddToken(c *gin.Context) { | ||||
| 		ExpiredTime:    token.ExpiredTime, | ||||
| 		RemainQuota:    token.RemainQuota, | ||||
| 		UnlimitedQuota: token.UnlimitedQuota, | ||||
| 		Models:         token.Models, | ||||
| 		Subnet:         token.Subnet, | ||||
| 	} | ||||
| 	err = cleanToken.Insert() | ||||
| 	if err != nil { | ||||
| @@ -177,10 +196,11 @@ func UpdateToken(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(token.Name) > 30 { | ||||
| 	err = validateToken(c, token) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": "令牌名称过长", | ||||
| 			"message": fmt.Sprintf("参数错误:%s", err.Error()), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| @@ -216,6 +236,8 @@ func UpdateToken(c *gin.Context) { | ||||
| 		cleanToken.ExpiredTime = token.ExpiredTime | ||||
| 		cleanToken.RemainQuota = token.RemainQuota | ||||
| 		cleanToken.UnlimitedQuota = token.UnlimitedQuota | ||||
| 		cleanToken.Models = token.Models | ||||
| 		cleanToken.Subnet = token.Subnet | ||||
| 	} | ||||
| 	err = cleanToken.Update() | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -58,11 +58,11 @@ func Login(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	setupLogin(&user, c) | ||||
| 	SetupLogin(&user, c) | ||||
| } | ||||
|  | ||||
| // setup session & cookies and then return user info | ||||
| func setupLogin(user *model.User, c *gin.Context) { | ||||
| func SetupLogin(user *model.User, c *gin.Context) { | ||||
| 	session := sessions.Default(c) | ||||
| 	session.Set("id", user.Id) | ||||
| 	session.Set("username", user.Username) | ||||
| @@ -180,27 +180,27 @@ func Register(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetAllUsers(c *gin.Context) { | ||||
|     p, _ := strconv.Atoi(c.Query("p")) | ||||
|     if p < 0 { | ||||
|         p = 0 | ||||
|     } | ||||
|      | ||||
|     order := c.DefaultQuery("order", "") | ||||
|     users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||
| 	 | ||||
|     if err != nil { | ||||
|         c.JSON(http.StatusOK, gin.H{ | ||||
|             "success": false, | ||||
|             "message": err.Error(), | ||||
|         }) | ||||
|         return | ||||
|     } | ||||
|      | ||||
|     c.JSON(http.StatusOK, gin.H{ | ||||
|         "success": true, | ||||
|         "message": "", | ||||
|         "data":    users, | ||||
|     }) | ||||
| 	p, _ := strconv.Atoi(c.Query("p")) | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
|  | ||||
| 	order := c.DefaultQuery("order", "") | ||||
| 	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 		"data":    users, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func SearchUsers(c *gin.Context) { | ||||
| @@ -770,3 +770,38 @@ func TopUp(c *gin.Context) { | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| type adminTopUpRequest struct { | ||||
| 	UserId int    `json:"user_id"` | ||||
| 	Quota  int    `json:"quota"` | ||||
| 	Remark string `json:"remark"` | ||||
| } | ||||
|  | ||||
| func AdminTopUp(c *gin.Context) { | ||||
| 	req := adminTopUpRequest{} | ||||
| 	err := c.ShouldBindJSON(&req) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	if req.Remark == "" { | ||||
| 		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) | ||||
| 	} | ||||
| 	model.RecordTopupLog(req.UserId, req.Remark, req.Quota) | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|   | ||||
							
								
								
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| # 使用 API 操控 & 扩展 One API | ||||
| > 欢迎提交 PR 在此放上你的拓展项目。 | ||||
|  | ||||
| 例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 | ||||
|  | ||||
| 又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 | ||||
|  | ||||
| ## 鉴权 | ||||
| One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: | ||||
|  | ||||
|  | ||||
|  | ||||
| 之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: | ||||
|  | ||||
|  | ||||
| ## 请求格式与响应格式 | ||||
| One API 使用 JSON 格式进行请求和响应。 | ||||
|  | ||||
| 对于响应体,一般格式如下: | ||||
| ```json | ||||
| { | ||||
|   "message": "请求信息", | ||||
|   "success": true, | ||||
|   "data": {} | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## API 列表 | ||||
| > 当前 API 列表不全,请自行通过浏览器抓取前端请求 | ||||
|  | ||||
| 如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 | ||||
|  | ||||
| ### 获取当前登录用户信息 | ||||
| **GET** `/api/user/self` | ||||
|  | ||||
| ### 为给定用户充值额度 | ||||
| **POST** `/api/topup` | ||||
| ```json | ||||
| { | ||||
|   "user_id": 1, | ||||
|   "quota": 100000, | ||||
|   "remark": "充值 100000 额度" | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 其他 | ||||
| ### 充值链接上的附加参数 | ||||
| One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: | ||||
| `https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` | ||||
|  | ||||
| 你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 | ||||
|  | ||||
| 注意,不是所有主题都支持该功能,欢迎 PR 补齐。 | ||||
							
								
								
									
										4
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
									
									
									
									
								
							| @@ -15,6 +15,7 @@ require ( | ||||
| 	github.com/google/uuid v1.3.0 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||
| 	github.com/smartystreets/goconvey v1.8.1 | ||||
| 	github.com/stretchr/testify v1.8.3 | ||||
| 	golang.org/x/crypto v0.17.0 | ||||
| 	golang.org/x/image v0.14.0 | ||||
| @@ -37,6 +38,7 @@ require ( | ||||
| 	github.com/go-playground/universal-translator v0.18.1 // indirect | ||||
| 	github.com/go-sql-driver/mysql v1.6.0 // indirect | ||||
| 	github.com/goccy/go-json v0.10.2 // indirect | ||||
| 	github.com/gopherjs/gopherjs v1.17.2 // indirect | ||||
| 	github.com/gorilla/context v1.1.1 // indirect | ||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | ||||
| 	github.com/gorilla/sessions v1.2.1 // indirect | ||||
| @@ -47,6 +49,7 @@ require ( | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/jinzhu/now v1.1.5 // indirect | ||||
| 	github.com/json-iterator/go v1.1.12 // indirect | ||||
| 	github.com/jtolds/gls v4.20.0+incompatible // indirect | ||||
| 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect | ||||
| 	github.com/leodido/go-urn v1.2.4 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.19 // indirect | ||||
| @@ -55,6 +58,7 @@ require ( | ||||
| 	github.com/modern-go/reflect2 v1.0.2 // indirect | ||||
| 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/smarty/assertions v1.15.0 // indirect | ||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||
| 	golang.org/x/arch v0.3.0 // indirect | ||||
|   | ||||
							
								
								
									
										12
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								go.sum
									
									
									
									
									
								
							| @@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL | ||||
| github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||
| github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= | ||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||
| github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= | ||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= | ||||
| github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= | ||||
| github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= | ||||
| github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= | ||||
| github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= | ||||
| github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= | ||||
| github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= | ||||
| github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | ||||
| github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | ||||
| @@ -85,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/ | ||||
| github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= | ||||
| github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= | ||||
| github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= | ||||
| github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= | ||||
| github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= | ||||
| github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= | ||||
| github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= | ||||
| @@ -127,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN | ||||
| github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= | ||||
| github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= | ||||
| github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= | ||||
| github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= | ||||
| github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= | ||||
| github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= | ||||
| github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= | ||||
| github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||||
| github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= | ||||
| github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= | ||||
| @@ -177,8 +185,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= | ||||
| golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||||
| google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | ||||
| google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= | ||||
|   | ||||
| @@ -1,10 +1,12 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/blacklist" | ||||
| 	"github.com/songquanpeng/one-api/common/network" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -88,6 +90,7 @@ func RootAuth() func(c *gin.Context) { | ||||
|  | ||||
| func TokenAuth() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		ctx := c.Request.Context() | ||||
| 		key := c.Request.Header.Get("Authorization") | ||||
| 		key = strings.TrimPrefix(key, "Bearer ") | ||||
| 		key = strings.TrimPrefix(key, "sk-") | ||||
| @@ -98,6 +101,12 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		if token.Subnet != nil && *token.Subnet != "" { | ||||
| 			if !network.IsIpInSubnet(ctx, c.ClientIP(), *token.Subnet) { | ||||
| 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | ||||
| 		if err != nil { | ||||
| 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||
| @@ -107,6 +116,19 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||
| 			return | ||||
| 		} | ||||
| 		requestModel, err := getRequestModel(c) | ||||
| 		if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") { | ||||
| 			abortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		c.Set("request_model", requestModel) | ||||
| 		if token.Models != nil && *token.Models != "" { | ||||
| 			c.Set("available_models", *token.Models) | ||||
| 			if requestModel != "" && !isModelInList(requestModel, *token.Models) { | ||||
| 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		c.Set("id", token.UserId) | ||||
| 		c.Set("token_id", token.Id) | ||||
| 		c.Set("token_name", token.Name) | ||||
|   | ||||
| @@ -2,14 +2,12 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
|  | ||||
| type ModelRequest struct { | ||||
| @@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			// Select a channel for the user | ||||
| 			var modelRequest ModelRequest | ||||
| 			err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 			requestModel := c.GetString("request_model") | ||||
| 			var err error | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) | ||||
| 			if err != nil { | ||||
| 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||
| 				return | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "text-moderation-stable" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = c.Param("model") | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "dall-e-2" | ||||
| 				} | ||||
| 			} | ||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 				if modelRequest.Model == "" { | ||||
| 					modelRequest.Model = "whisper-1" | ||||
| 				} | ||||
| 			} | ||||
| 			requestModel = modelRequest.Model | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) | ||||
| 			if err != nil { | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) | ||||
| 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) | ||||
| 				if channel != nil { | ||||
| 					logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||
| 					message = "数据库一致性已被破坏,请联系管理员" | ||||
|   | ||||
| @@ -1,9 +1,12 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| 	c.Abort() | ||||
| 	logger.Error(c.Request.Context(), message) | ||||
| } | ||||
|  | ||||
| func getRequestModel(c *gin.Context) (string, error) { | ||||
| 	var modelRequest ModelRequest | ||||
| 	err := common.UnmarshalBodyReusable(c, &modelRequest) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "text-moderation-stable" | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasSuffix(c.Request.URL.Path, "embeddings") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = c.Param("model") | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "dall-e-2" | ||||
| 		} | ||||
| 	} | ||||
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { | ||||
| 		if modelRequest.Model == "" { | ||||
| 			modelRequest.Model = "whisper-1" | ||||
| 		} | ||||
| 	} | ||||
| 	return modelRequest.Model, nil | ||||
| } | ||||
|  | ||||
| func isModelInList(modelName string, models string) bool { | ||||
| 	modelList := strings.Split(models, ",") | ||||
| 	for _, model := range modelList { | ||||
| 		if modelName == model { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|   | ||||
| @@ -1,8 +1,10 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"gorm.io/gorm" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| @@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { | ||||
| func UpdateAbilityStatus(channelId int, status bool) error { | ||||
| 	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error | ||||
| } | ||||
|  | ||||
| func GetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||
| 	groupCol := "`group`" | ||||
| 	trueVal := "1" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		groupCol = `"group"` | ||||
| 		trueVal = "true" | ||||
| 	} | ||||
| 	var models []string | ||||
| 	err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	sort.Strings(models) | ||||
| 	return models, err | ||||
| } | ||||
|   | ||||
| @@ -21,6 +21,7 @@ var ( | ||||
| 	UserId2GroupCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2QuotaCacheSeconds  = config.SyncFrequency | ||||
| 	UserId2StatusCacheSeconds = config.SyncFrequency | ||||
| 	GroupModelsCacheSeconds   = config.SyncFrequency | ||||
| ) | ||||
|  | ||||
| func CacheGetTokenByKey(key string) (*Token, error) { | ||||
| @@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) { | ||||
| 	return userEnabled, err | ||||
| } | ||||
|  | ||||
| func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return GetGroupModels(ctx, group) | ||||
| 	} | ||||
| 	modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) | ||||
| 	if err == nil { | ||||
| 		return strings.Split(modelsStr, ","), nil | ||||
| 	} | ||||
| 	models, err := GetGroupModels(ctx, group) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("Redis set group models error: " + err.Error()) | ||||
| 	} | ||||
| 	return models, nil | ||||
| } | ||||
|  | ||||
| var group2model2channels map[string]map[string][]*Channel | ||||
| var channelSyncLock sync.RWMutex | ||||
|  | ||||
|   | ||||
							
								
								
									
										15
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RecordTopupLog(userId int, content string, quota int) { | ||||
| 	log := &Log{ | ||||
| 		UserId:    userId, | ||||
| 		Username:  GetUsernameById(userId), | ||||
| 		CreatedAt: helper.GetTimestamp(), | ||||
| 		Type:      LogTypeTopup, | ||||
| 		Content:   content, | ||||
| 		Quota:     quota, | ||||
| 	} | ||||
| 	err := LOG_DB.Create(log).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to record log: " + err.Error()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { | ||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| 	if !config.LogConsumeEnabled { | ||||
|   | ||||
| @@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		config.GitHubClientId = value | ||||
| 	case "GitHubClientSecret": | ||||
| 		config.GitHubClientSecret = value | ||||
| 	case "LarkClientId": | ||||
| 		config.LarkClientId = value | ||||
| 	case "LarkClientSecret": | ||||
| 		config.LarkClientSecret = value | ||||
| 	case "Footer": | ||||
| 		config.Footer = value | ||||
| 	case "SystemName": | ||||
|   | ||||
| @@ -12,24 +12,26 @@ import ( | ||||
| ) | ||||
|  | ||||
| type Token struct { | ||||
| 	Id             int    `json:"id"` | ||||
| 	UserId         int    `json:"user_id"` | ||||
| 	Key            string `json:"key" gorm:"type:char(48);uniqueIndex"` | ||||
| 	Status         int    `json:"status" gorm:"default:1"` | ||||
| 	Name           string `json:"name" gorm:"index" ` | ||||
| 	CreatedTime    int64  `json:"created_time" gorm:"bigint"` | ||||
| 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"` | ||||
| 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||
| 	RemainQuota    int64  `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64  `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Id             int     `json:"id"` | ||||
| 	UserId         int     `json:"user_id"` | ||||
| 	Key            string  `json:"key" gorm:"type:char(48);uniqueIndex"` | ||||
| 	Status         int     `json:"status" gorm:"default:1"` | ||||
| 	Name           string  `json:"name" gorm:"index" ` | ||||
| 	CreatedTime    int64   `json:"created_time" gorm:"bigint"` | ||||
| 	AccessedTime   int64   `json:"accessed_time" gorm:"bigint"` | ||||
| 	ExpiredTime    int64   `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||
| 	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"` | ||||
| 	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota | ||||
| 	Models         *string `json:"models" gorm:"default:''"`           // allowed models | ||||
| 	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet | ||||
| } | ||||
|  | ||||
| func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { | ||||
| 	var tokens []*Token | ||||
| 	var err error | ||||
| 	query := DB.Where("user_id = ?", userId) | ||||
| 	 | ||||
|  | ||||
| 	switch order { | ||||
| 	case "remain_quota": | ||||
| 		query = query.Order("unlimited_quota desc, remain_quota desc") | ||||
| @@ -38,7 +40,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token | ||||
| 	default: | ||||
| 		query = query.Order("id desc") | ||||
| 	} | ||||
| 	 | ||||
|  | ||||
| 	err = query.Limit(num).Offset(startIdx).Find(&tokens).Error | ||||
| 	return tokens, err | ||||
| } | ||||
| @@ -61,7 +63,7 @@ func ValidateUserToken(key string) (token *Token, err error) { | ||||
| 		return nil, errors.New("令牌验证失败") | ||||
| 	} | ||||
| 	if token.Status == common.TokenStatusExhausted { | ||||
| 		return nil, errors.New("该令牌额度已用尽") | ||||
| 		return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) | ||||
| 	} else if token.Status == common.TokenStatusExpired { | ||||
| 		return nil, errors.New("该令牌已过期") | ||||
| 	} | ||||
| @@ -121,7 +123,7 @@ func (token *Token) Insert() error { | ||||
| // Update Make sure your token's fields is completed, because this will update non-zero values | ||||
| func (token *Token) Update() error { | ||||
| 	var err error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error | ||||
| 	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -24,6 +24,7 @@ type User struct { | ||||
| 	Email            string `json:"email" gorm:"index" validate:"max=50"` | ||||
| 	GitHubId         string `json:"github_id" gorm:"column:github_id;index"` | ||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||
| 	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"` | ||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||
| 	Quota            int64  `json:"quota" gorm:"bigint;default:0"` | ||||
| @@ -41,21 +42,21 @@ func GetMaxUserId() int { | ||||
| } | ||||
|  | ||||
| func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { | ||||
|     query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) | ||||
|      | ||||
|     switch order { | ||||
|     case "quota": | ||||
|         query = query.Order("quota desc") | ||||
|     case "used_quota": | ||||
|         query = query.Order("used_quota desc") | ||||
|     case "request_count": | ||||
|         query = query.Order("request_count desc") | ||||
|     default: | ||||
|         query = query.Order("id desc") | ||||
|     } | ||||
|      | ||||
|     err = query.Find(&users).Error | ||||
|     return users, err | ||||
| 	query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) | ||||
|  | ||||
| 	switch order { | ||||
| 	case "quota": | ||||
| 		query = query.Order("quota desc") | ||||
| 	case "used_quota": | ||||
| 		query = query.Order("used_quota desc") | ||||
| 	case "request_count": | ||||
| 		query = query.Order("request_count desc") | ||||
| 	default: | ||||
| 		query = query.Order("id desc") | ||||
| 	} | ||||
|  | ||||
| 	err = query.Find(&users).Error | ||||
| 	return users, err | ||||
| } | ||||
|  | ||||
| func SearchUsers(keyword string) (users []*User, err error) { | ||||
| @@ -206,6 +207,14 @@ func (user *User) FillUserByGitHubId() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByLarkId() error { | ||||
| 	if user.LarkId == "" { | ||||
| 		return errors.New("lark id 为空!") | ||||
| 	} | ||||
| 	DB.Where(User{LarkId: user.LarkId}).First(user) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (user *User) FillUserByWeChatId() error { | ||||
| 	if user.WeChatId == "" { | ||||
| 		return errors.New("WeChat id 为空!") | ||||
| @@ -234,6 +243,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsLarkIdAlreadyTaken(githubId string) bool { | ||||
| 	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|  | ||||
| func IsUsernameAlreadyTaken(username string) bool { | ||||
| 	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 | ||||
| } | ||||
|   | ||||
| @@ -48,6 +48,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			MaxTokens:         request.MaxTokens, | ||||
| 			Temperature:       request.Temperature, | ||||
| 			TopP:              request.TopP, | ||||
| 			TopK:              request.TopK, | ||||
| 			ResultFormat:      "message", | ||||
| 			Tools:             request.Tools, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
| @@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR | ||||
| } | ||||
|  | ||||
| func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Output.Text, | ||||
| 		}, | ||||
| 		FinishReason: response.Output.FinishReason, | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      response.RequestId, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 		Choices: response.Output.Choices, | ||||
| 		Usage: model.Usage{ | ||||
| 			PromptTokens:     response.Usage.InputTokens, | ||||
| 			CompletionTokens: response.Usage.OutputTokens, | ||||
| @@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| } | ||||
|  | ||||
| func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	if len(aliResponse.Output.Choices) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	aliChoice := aliResponse.Output.Choices[0] | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = aliResponse.Output.Text | ||||
| 	if aliResponse.Output.FinishReason != "null" { | ||||
| 		finishReason := aliResponse.Output.FinishReason | ||||
| 	choice.Delta = aliChoice.Message | ||||
| 	if aliChoice.FinishReason != "null" { | ||||
| 		finishReason := aliChoice.FinishReason | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| @@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens | ||||
| 			} | ||||
| 			response := streamResponseAli2OpenAI(&aliResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) | ||||
| 			//lastResponseText = aliResponse.Output.Text | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| @@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var aliResponse ChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| @@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	logger.Debugf(ctx, "response body: %s\n", responseBody) | ||||
| 	err = json.Unmarshal(responseBody, &aliResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
|   | ||||
| @@ -1,5 +1,10 @@ | ||||
| package ali | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type Message struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| @@ -11,13 +16,15 @@ type Input struct { | ||||
| } | ||||
|  | ||||
| type Parameters struct { | ||||
| 	TopP              float64 `json:"top_p,omitempty"` | ||||
| 	TopK              int     `json:"top_k,omitempty"` | ||||
| 	Seed              uint64  `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool    `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool    `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int     `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64 `json:"temperature,omitempty"` | ||||
| 	TopP              float64      `json:"top_p,omitempty"` | ||||
| 	TopK              int          `json:"top_k,omitempty"` | ||||
| 	Seed              uint64       `json:"seed,omitempty"` | ||||
| 	EnableSearch      bool         `json:"enable_search,omitempty"` | ||||
| 	IncrementalOutput bool         `json:"incremental_output,omitempty"` | ||||
| 	MaxTokens         int          `json:"max_tokens,omitempty"` | ||||
| 	Temperature       float64      `json:"temperature,omitempty"` | ||||
| 	ResultFormat      string       `json:"result_format,omitempty"` | ||||
| 	Tools             []model.Tool `json:"tools,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| @@ -62,8 +69,9 @@ type Usage struct { | ||||
| } | ||||
|  | ||||
| type Output struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	FinishReason string `json:"finish_reason"` | ||||
| 	//Text         string                      `json:"text"` | ||||
| 	//FinishReason string                      `json:"finish_reason"` | ||||
| 	Choices []openai.TextResponseChoice `json:"choices"` | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
|   | ||||
| @@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 		TopP:        textRequest.TopP, | ||||
| 		TopK:        textRequest.TopK, | ||||
| 		Stream:      textRequest.Stream, | ||||
| 	} | ||||
| 	if claudeRequest.MaxTokens == 0 { | ||||
|   | ||||
| @@ -38,16 +38,26 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-Bot-4": | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-3.5-8K": | ||||
| 		suffix += "completions" | ||||
| 	case "ERNIE-Bot-8K": | ||||
| 		suffix += "ernie_bot_8k" | ||||
| 	case "ERNIE-Bot": | ||||
| 		suffix += "completions" | ||||
| 	case "ERNIE-Speed": | ||||
| 		suffix += "ernie_speed" | ||||
| 	case "ERNIE-Bot-turbo": | ||||
| 		suffix += "eb-instant" | ||||
| 	case "ERNIE-Speed": | ||||
| 		suffix += "ernie_speed" | ||||
| 	case "ERNIE-Bot-8K": | ||||
| 		suffix += "ernie_bot_8k" | ||||
| 	case "ERNIE-4.0-8K": | ||||
| 		suffix += "completions_pro" | ||||
| 	case "ERNIE-3.5-8K": | ||||
| 		suffix += "completions" | ||||
| 	case "ERNIE-Speed-8K": | ||||
| 		suffix += "ernie_speed" | ||||
| 	case "ERNIE-Speed-128K": | ||||
| 		suffix += "ernie-speed-128k" | ||||
| 	case "ERNIE-Lite-8K": | ||||
| 		suffix += "ernie-lite-8k" | ||||
| 	case "ERNIE-Tiny-8K": | ||||
| 		suffix += "ernie-tiny-8k" | ||||
| 	case "BLOOMZ-7B": | ||||
| 		suffix += "bloomz_7b1" | ||||
| 	case "Embedding-V1": | ||||
| @@ -59,7 +69,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	case "tao-8k": | ||||
| 		suffix += "tao_8k" | ||||
| 	default: | ||||
| 		suffix += meta.ActualModelName | ||||
| 		suffix += strings.ToLower(meta.ActualModelName) | ||||
| 	} | ||||
| 	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix) | ||||
| 	var accessToken string | ||||
|   | ||||
| @@ -1,11 +1,18 @@ | ||||
| package baidu | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"ERNIE-Bot-4", | ||||
| 	"ERNIE-Bot-8K", | ||||
| 	"ERNIE-Bot", | ||||
| 	"ERNIE-Speed", | ||||
| 	"ERNIE-Bot-turbo", | ||||
| 	"ERNIE-4.0-8K", | ||||
| 	"ERNIE-Bot-8K-0922", | ||||
| 	"ERNIE-3.5-8K", | ||||
| 	"ERNIE-Lite-8K-0922", | ||||
| 	"ERNIE-Speed-8K", | ||||
| 	"ERNIE-3.5-4K-0205", | ||||
| 	"ERNIE-3.5-8K-0205", | ||||
| 	"ERNIE-3.5-8K-1222", | ||||
| 	"ERNIE-Lite-8K", | ||||
| 	"ERNIE-Speed-128K", | ||||
| 	"ERNIE-Tiny-8K", | ||||
| 	"BLOOMZ-7B", | ||||
| 	"Embedding-V1", | ||||
| 	"bge-large-zh", | ||||
| 	"bge-large-en", | ||||
|   | ||||
| @@ -70,8 +70,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		var responseText string | ||||
| 		err, responseText, _ = StreamHandler(c, resp, meta.Mode) | ||||
| 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 		err, responseText, usage = StreamHandler(c, resp, meta.Mode) | ||||
| 		if usage == nil { | ||||
| 			usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 		} | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/stepfun" | ||||
| ) | ||||
|  | ||||
| var CompatibleChannels = []int{ | ||||
| @@ -20,6 +21,7 @@ var CompatibleChannels = []int{ | ||||
| 	common.ChannelTypeMistral, | ||||
| 	common.ChannelTypeGroq, | ||||
| 	common.ChannelTypeLingYiWanWu, | ||||
| 	common.ChannelTypeStepFun, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| @@ -40,6 +42,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 		return "groq", groq.ModelList | ||||
| 	case common.ChannelTypeLingYiWanWu: | ||||
| 		return "lingyiwanwu", lingyiwanwu.ModelList | ||||
| 	case common.ChannelTypeStepFun: | ||||
| 		return "stepfun", stepfun.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| @@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 						continue // just ignore the error | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Delta.Content | ||||
| 						responseText += conv.AsString(choice.Delta.Content) | ||||
| 					} | ||||
| 					if streamResponse.Usage != nil { | ||||
| 						usage = streamResponse.Usage | ||||
|   | ||||
| @@ -118,12 +118,9 @@ type ImageResponse struct { | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponseChoice struct { | ||||
| 	Index int `json:"index"` | ||||
| 	Delta struct { | ||||
| 		Content string `json:"content"` | ||||
| 		Role    string `json:"role,omitempty"` | ||||
| 	} `json:"delta"` | ||||
| 	FinishReason *string `json:"finish_reason,omitempty"` | ||||
| 	Index        int           `json:"index"` | ||||
| 	Delta        model.Message `json:"delta"` | ||||
| 	FinishReason *string       `json:"finish_reason,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatCompletionsStreamResponse struct { | ||||
|   | ||||
							
								
								
									
										7
									
								
								relay/channel/stepfun/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/channel/stepfun/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package stepfun | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"step-1-32k", | ||||
| 	"step-1v-32k", | ||||
| 	"step-1-200k", | ||||
| } | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| @@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC | ||||
| 			} | ||||
| 			response := streamResponseTencent2OpenAI(&TencentResponse) | ||||
| 			if len(response.Choices) != 0 { | ||||
| 				responseText += response.Choices[0].Delta.Content | ||||
| 				responseText += conv.AsString(response.Choices[0].Delta.Content) | ||||
| 			} | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
|   | ||||
| @@ -26,7 +26,11 @@ import ( | ||||
|  | ||||
| func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	var lastToolCalls []model.Tool | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.ToolCalls != nil { | ||||
| 			lastToolCalls = message.ToolCalls | ||||
| 		} | ||||
| 		messages = append(messages, Message{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.StringContent(), | ||||
| @@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string | ||||
| 	xunfeiRequest.Parameter.Chat.TopK = request.N | ||||
| 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens | ||||
| 	xunfeiRequest.Payload.Message.Text = messages | ||||
| 	if len(lastToolCalls) != 0 { | ||||
| 		for _, toolCall := range lastToolCalls { | ||||
| 			xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &xunfeiRequest | ||||
| } | ||||
|  | ||||
| func getToolCalls(response *ChatResponse) []model.Tool { | ||||
| 	var toolCalls []model.Tool | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		return toolCalls | ||||
| 	} | ||||
| 	item := response.Payload.Choices.Text[0] | ||||
| 	if item.FunctionCall == nil { | ||||
| 		return toolCalls | ||||
| 	} | ||||
| 	toolCall := model.Tool{ | ||||
| 		Id:       fmt.Sprintf("call_%s", helper.GetUUID()), | ||||
| 		Type:     "function", | ||||
| 		Function: *item.FunctionCall, | ||||
| 	} | ||||
| 	toolCalls = append(toolCalls, toolCall) | ||||
| 	return toolCalls | ||||
| } | ||||
|  | ||||
| func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	if len(response.Payload.Choices.Text) == 0 { | ||||
| 		response.Payload.Choices.Text = []ChatResponseTextItem{ | ||||
| @@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: response.Payload.Choices.Text[0].Content, | ||||
| 			Role:      "assistant", | ||||
| 			Content:   response.Payload.Choices.Text[0].Content, | ||||
| 			ToolCalls: getToolCalls(response), | ||||
| 		}, | ||||
| 		FinishReason: constant.StopFinishReason, | ||||
| 	} | ||||
| @@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl | ||||
| 	} | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content | ||||
| 	choice.Delta.ToolCalls = getToolCalls(xunfeiResponse) | ||||
| 	if xunfeiResponse.Payload.Choices.Status == 2 { | ||||
| 		choice.FinishReason = &constant.StopFinishReason | ||||
| 	} | ||||
|   | ||||
| @@ -26,13 +26,18 @@ type ChatRequest struct { | ||||
| 		Message struct { | ||||
| 			Text []Message `json:"text"` | ||||
| 		} `json:"message"` | ||||
| 		Functions struct { | ||||
| 			Text []model.Function `json:"text,omitempty"` | ||||
| 		} `json:"functions,omitempty"` | ||||
| 	} `json:"payload"` | ||||
| } | ||||
|  | ||||
| type ChatResponseTextItem struct { | ||||
| 	Content string `json:"content"` | ||||
| 	Role    string `json:"role"` | ||||
| 	Index   int    `json:"index"` | ||||
| 	Content      string          `json:"content"` | ||||
| 	Role         string          `json:"role"` | ||||
| 	Index        int             `json:"index"` | ||||
| 	ContentType  string          `json:"content_type"` | ||||
| 	FunctionCall *model.Function `json:"function_call"` | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| @@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	if a.APIVersion == "v4" { | ||||
| 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil | ||||
| 	} | ||||
| 	if meta.Mode == constant.RelayModeEmbeddings { | ||||
| 		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil | ||||
| 	} | ||||
| 	method := "invoke" | ||||
| 	if meta.IsStream { | ||||
| 		method = "sse-invoke" | ||||
| @@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	// TopP (0.0, 1.0) | ||||
| 	request.TopP = math.Min(0.99, request.TopP) | ||||
| 	request.TopP = math.Max(0.01, request.TopP) | ||||
| 	switch relayMode { | ||||
| 	case constant.RelayModeEmbeddings: | ||||
| 		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) | ||||
| 		return baiduEmbeddingRequest, nil | ||||
| 	default: | ||||
| 		// TopP (0.0, 1.0) | ||||
| 		request.TopP = math.Min(0.99, request.TopP) | ||||
| 		request.TopP = math.Max(0.01, request.TopP) | ||||
|  | ||||
| 	// Temperature (0.0, 1.0) | ||||
| 	request.Temperature = math.Min(0.99, request.Temperature) | ||||
| 	request.Temperature = math.Max(0.01, request.Temperature) | ||||
| 	a.SetVersionByModeName(request.Model) | ||||
| 	if a.APIVersion == "v4" { | ||||
| 		return request, nil | ||||
| 		// Temperature (0.0, 1.0) | ||||
| 		request.Temperature = math.Min(0.99, request.Temperature) | ||||
| 		request.Temperature = math.Max(0.01, request.Temperature) | ||||
| 		a.SetVersionByModeName(request.Model) | ||||
| 		if a.APIVersion == "v4" { | ||||
| 			return request, nil | ||||
| 		} | ||||
| 		return ConvertRequest(*request), nil | ||||
| 	} | ||||
| 	return ConvertRequest(*request), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||
| @@ -84,14 +94,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel | ||||
| 	if a.APIVersion == "v4" { | ||||
| 		return a.DoResponseV4(c, resp, meta) | ||||
| 	} | ||||
|  | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp) | ||||
| 		if meta.Mode == constant.RelayModeEmbeddings { | ||||
| 			err, usage = EmbeddingsHandler(c, resp) | ||||
| 		} else { | ||||
| 			err, usage = Handler(c, resp) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { | ||||
| 	return &EmbeddingRequest{ | ||||
| 		Model: "embedding-2", | ||||
| 		Input: request.Input.(string), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|   | ||||
| @@ -2,5 +2,5 @@ package zhipu | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", | ||||
| 	"glm-4", "glm-4v", "glm-3-turbo", | ||||
| 	"glm-4", "glm-4v", "glm-3-turbo", "embedding-2", | ||||
| } | ||||
|   | ||||
| @@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var zhipuResponse EmbeddingRespone | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &zhipuResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
|  | ||||
| func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse { | ||||
| 	openAIEmbeddingResponse := openai.EmbeddingResponse{ | ||||
| 		Object: "list", | ||||
| 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)), | ||||
| 		Model:  response.Model, | ||||
| 		Usage: model.Usage{ | ||||
| 			PromptTokens:     response.PromptTokens, | ||||
| 			CompletionTokens: response.CompletionTokens, | ||||
| 			TotalTokens:      response.Usage.TotalTokens, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, item := range response.Embeddings { | ||||
| 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ | ||||
| 			Object:    `embedding`, | ||||
| 			Index:     item.Index, | ||||
| 			Embedding: item.Embedding, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &openAIEmbeddingResponse | ||||
| } | ||||
|   | ||||
| @@ -44,3 +44,21 @@ type tokenData struct { | ||||
| 	Token      string | ||||
| 	ExpiryTime time.Time | ||||
| } | ||||
|  | ||||
| type EmbeddingRequest struct { | ||||
| 	Model string `json:"model"` | ||||
| 	Input string `json:"input"` | ||||
| } | ||||
|  | ||||
| type EmbeddingRespone struct { | ||||
| 	Model       string          `json:"model"` | ||||
| 	Object      string          `json:"object"` | ||||
| 	Embeddings  []EmbeddingData `json:"data"` | ||||
| 	model.Usage `json:"usage"` | ||||
| } | ||||
|  | ||||
| type EmbeddingData struct { | ||||
| 	Index     int       `json:"index"` | ||||
| 	Object    string    `json:"object"` | ||||
| 	Embedding []float64 `json:"embedding"` | ||||
| } | ||||
|   | ||||
| @@ -5,25 +5,29 @@ type ResponseFormat struct { | ||||
| } | ||||
|  | ||||
| type GeneralOpenAIRequest struct { | ||||
| 	Model            string          `json:"model,omitempty"` | ||||
| 	Messages         []Message       `json:"messages,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Stream           bool            `json:"stream,omitempty"` | ||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	N                int             `json:"n,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| 	Functions        any             `json:"functions,omitempty"` | ||||
| 	Model            string          `json:"model,omitempty"` | ||||
| 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"` | ||||
| 	MaxTokens        int             `json:"max_tokens,omitempty"` | ||||
| 	N                int             `json:"n,omitempty"` | ||||
| 	PresencePenalty  float64         `json:"presence_penalty,omitempty"` | ||||
| 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"` | ||||
| 	Seed             float64         `json:"seed,omitempty"` | ||||
| 	Tools            any             `json:"tools,omitempty"` | ||||
| 	Stream           bool            `json:"stream,omitempty"` | ||||
| 	Temperature      float64         `json:"temperature,omitempty"` | ||||
| 	TopP             float64         `json:"top_p,omitempty"` | ||||
| 	TopK             int             `json:"top_k,omitempty"` | ||||
| 	Tools            []Tool          `json:"tools,omitempty"` | ||||
| 	ToolChoice       any             `json:"tool_choice,omitempty"` | ||||
| 	FunctionCall     any             `json:"function_call,omitempty"` | ||||
| 	Functions        any             `json:"functions,omitempty"` | ||||
| 	User             string          `json:"user,omitempty"` | ||||
| 	Prompt           any             `json:"prompt,omitempty"` | ||||
| 	Input            any             `json:"input,omitempty"` | ||||
| 	EncodingFormat   string          `json:"encoding_format,omitempty"` | ||||
| 	Dimensions       int             `json:"dimensions,omitempty"` | ||||
| 	Instruction      string          `json:"instruction,omitempty"` | ||||
| 	Size             string          `json:"size,omitempty"` | ||||
| } | ||||
|  | ||||
| func (r GeneralOpenAIRequest) ParseInput() []string { | ||||
|   | ||||
| @@ -1,9 +1,10 @@ | ||||
| package model | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string  `json:"role"` | ||||
| 	Content any     `json:"content"` | ||||
| 	Name    *string `json:"name,omitempty"` | ||||
| 	Role      string  `json:"role,omitempty"` | ||||
| 	Content   any     `json:"content,omitempty"` | ||||
| 	Name      *string `json:"name,omitempty"` | ||||
| 	ToolCalls []Tool  `json:"tool_calls,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m Message) IsStringContent() bool { | ||||
|   | ||||
							
								
								
									
										14
									
								
								relay/model/tool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								relay/model/tool.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| package model | ||||
|  | ||||
| type Tool struct { | ||||
| 	Id       string   `json:"id,omitempty"` | ||||
| 	Type     string   `json:"type"` | ||||
| 	Function Function `json:"function"` | ||||
| } | ||||
|  | ||||
| type Function struct { | ||||
| 	Description string `json:"description,omitempty"` | ||||
| 	Name        string `json:"name"` | ||||
| 	Parameters  any    `json:"parameters,omitempty"` // request | ||||
| 	Arguments   any    `json:"arguments,omitempty"`  // response | ||||
| } | ||||
| @@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { | ||||
| 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||
| 		return true | ||||
| 	} | ||||
| 	//if strings.Contains(err.Message, "quota") { | ||||
| 	//	return true | ||||
| 	//} | ||||
| 	if strings.Contains(err.Message, "credit") { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.Contains(err.Message, "balance") { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package router | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/controller/auth" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
|  | ||||
| 	"github.com/gin-contrib/gzip" | ||||
| @@ -21,11 +22,13 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 		apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) | ||||
| 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | ||||
| 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | ||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | ||||
| 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) | ||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | ||||
| 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | ||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) | ||||
| 		apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) | ||||
| 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) | ||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) | ||||
| 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind) | ||||
| 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | ||||
| 		apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) | ||||
|  | ||||
| 		userRoute := apiRouter.Group("/user") | ||||
| 		{ | ||||
| @@ -43,6 +46,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 				selfRoute.GET("/token", controller.GenerateAccessToken) | ||||
| 				selfRoute.GET("/aff", controller.GetAffCode) | ||||
| 				selfRoute.POST("/topup", controller.TopUp) | ||||
| 				selfRoute.GET("/available_models", controller.GetUserAvailableModels) | ||||
| 			} | ||||
|  | ||||
| 			adminRoute := userRoute.Group("/") | ||||
|   | ||||
| @@ -2,6 +2,9 @@ | ||||
|  | ||||
| > 每个文件夹代表一个主题,欢迎提交你的主题 | ||||
|  | ||||
| > [!WARNING] | ||||
| > 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR | ||||
|  | ||||
| ## 提交新的主题 | ||||
|  | ||||
| > 欢迎在页面底部保留你和 One API 的版权信息以及指向链接 | ||||
|   | ||||
| @@ -107,6 +107,12 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 31, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   32: { | ||||
|     key: 32, | ||||
|     text: '阶跃星辰', | ||||
|     value: 32, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   8: { | ||||
|     key: 8, | ||||
|     text: '自定义渠道', | ||||
|   | ||||
| @@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption'; | ||||
| import TopUp from './pages/TopUp'; | ||||
| import Log from './pages/Log'; | ||||
| import Chat from './pages/Chat'; | ||||
| import LarkOAuth from './components/LarkOAuth'; | ||||
|  | ||||
| const Home = lazy(() => import('./pages/Home')); | ||||
| const About = lazy(() => import('./pages/About')); | ||||
| @@ -239,6 +240,14 @@ function App() { | ||||
|           </Suspense> | ||||
|         } | ||||
|       /> | ||||
|       <Route | ||||
|         path='/oauth/lark' | ||||
|         element={ | ||||
|           <Suspense fallback={<Loading></Loading>}> | ||||
|             <LarkOAuth /> | ||||
|           </Suspense> | ||||
|         } | ||||
|       /> | ||||
|       <Route | ||||
|         path='/setting' | ||||
|         element={ | ||||
|   | ||||
							
								
								
									
										58
									
								
								web/default/src/components/LarkOAuth.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								web/default/src/components/LarkOAuth.js
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| import React, { useContext, useEffect, useState } from 'react'; | ||||
| import { Dimmer, Loader, Segment } from 'semantic-ui-react'; | ||||
| import { useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { API, showError, showSuccess } from '../helpers'; | ||||
| import { UserContext } from '../context/User'; | ||||
|  | ||||
| const LarkOAuth = () => { | ||||
|   const [searchParams, setSearchParams] = useSearchParams(); | ||||
|  | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
|   const [prompt, setPrompt] = useState('处理中...'); | ||||
|   const [processing, setProcessing] = useState(true); | ||||
|  | ||||
|   let navigate = useNavigate(); | ||||
|  | ||||
|   const sendCode = async (code, state, count) => { | ||||
|     const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       if (message === 'bind') { | ||||
|         showSuccess('绑定成功!'); | ||||
|         navigate('/setting'); | ||||
|       } else { | ||||
|         userDispatch({ type: 'login', payload: data }); | ||||
|         localStorage.setItem('user', JSON.stringify(data)); | ||||
|         showSuccess('登录成功!'); | ||||
|         navigate('/'); | ||||
|       } | ||||
|     } else { | ||||
|       showError(message); | ||||
|       if (count === 0) { | ||||
|         setPrompt(`操作失败,重定向至登录界面中...`); | ||||
|         navigate('/setting'); // in case this is failed to bind lark | ||||
|         return; | ||||
|       } | ||||
|       count++; | ||||
|       setPrompt(`出现错误,第 ${count} 次重试中...`); | ||||
|       await new Promise((resolve) => setTimeout(resolve, count * 2000)); | ||||
|       await sendCode(code, state, count); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let code = searchParams.get('code'); | ||||
|     let state = searchParams.get('state'); | ||||
|     sendCode(code, state, 0).then(); | ||||
|   }, []); | ||||
|  | ||||
|   return ( | ||||
|     <Segment style={{ minHeight: '300px' }}> | ||||
|       <Dimmer active inverted> | ||||
|         <Loader size='large'>{prompt}</Loader> | ||||
|       </Dimmer> | ||||
|     </Segment> | ||||
|   ); | ||||
| }; | ||||
|  | ||||
| export default LarkOAuth; | ||||
| @@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f | ||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||
| import { UserContext } from '../context/User'; | ||||
| import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; | ||||
| import { onGitHubOAuthClicked } from './utils'; | ||||
| import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils'; | ||||
| import larkIcon from '../images/lark.svg'; | ||||
|  | ||||
| const LoginForm = () => { | ||||
|   const [inputs, setInputs] = useState({ | ||||
| @@ -124,7 +125,7 @@ const LoginForm = () => { | ||||
|             点击注册 | ||||
|           </Link> | ||||
|         </Message> | ||||
|         {status.github_oauth || status.wechat_login ? ( | ||||
|         {status.github_oauth || status.wechat_login || status.lark_client_id ? ( | ||||
|           <> | ||||
|             <Divider horizontal>Or</Divider> | ||||
|             {status.github_oauth ? ( | ||||
| @@ -137,6 +138,18 @@ const LoginForm = () => { | ||||
|             ) : ( | ||||
|               <></> | ||||
|             )} | ||||
|             {status.lark_client_id ? ( | ||||
|               <Button | ||||
|                 // circular | ||||
|                 color='' | ||||
|                 onClick={() => onLarkOAuthClicked(status.lark_client_id)} | ||||
|                 style={{ padding: 0, width: 36, height: 36 }} | ||||
|               > | ||||
|                 <img src={larkIcon} width={36} height={36} /> | ||||
|               </Button> | ||||
|             ) : ( | ||||
|               <></> | ||||
|             )} | ||||
|             {status.wechat_login ? ( | ||||
|               <Button | ||||
|                 circular | ||||
|   | ||||
| @@ -4,7 +4,7 @@ import { Link, useNavigate } from 'react-router-dom'; | ||||
| import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | ||||
| import Turnstile from 'react-turnstile'; | ||||
| import { UserContext } from '../context/User'; | ||||
| import { onGitHubOAuthClicked } from './utils'; | ||||
| import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils'; | ||||
|  | ||||
| const PersonalSetting = () => { | ||||
|   const [userState, userDispatch] = useContext(UserContext); | ||||
| @@ -247,6 +247,11 @@ const PersonalSetting = () => { | ||||
|           <Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button> | ||||
|         ) | ||||
|       } | ||||
|       { | ||||
|         status.lark_client_id && ( | ||||
|           <Button onClick={()=>{onLarkOAuthClicked(status.lark_client_id)}}>绑定飞书账号</Button> | ||||
|         ) | ||||
|       } | ||||
|       <Button | ||||
|         onClick={() => { | ||||
|           setShowEmailBindModal(true); | ||||
|   | ||||
| @@ -10,6 +10,8 @@ const SystemSetting = () => { | ||||
|     GitHubOAuthEnabled: '', | ||||
|     GitHubClientId: '', | ||||
|     GitHubClientSecret: '', | ||||
|     LarkClientId: '', | ||||
|     LarkClientSecret: '', | ||||
|     Notice: '', | ||||
|     SMTPServer: '', | ||||
|     SMTPPort: '', | ||||
| @@ -109,6 +111,8 @@ const SystemSetting = () => { | ||||
|       name === 'ServerAddress' || | ||||
|       name === 'GitHubClientId' || | ||||
|       name === 'GitHubClientSecret' || | ||||
|       name === 'LarkClientId' || | ||||
|       name === 'LarkClientSecret' || | ||||
|       name === 'WeChatServerAddress' || | ||||
|       name === 'WeChatServerToken' || | ||||
|       name === 'WeChatAccountQRCodeImageURL' || | ||||
| @@ -212,6 +216,18 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|    const submitLarkOAuth = async () => { | ||||
|     if (originInputs['LarkClientId'] !== inputs.LarkClientId) { | ||||
|       await updateOption('LarkClientId', inputs.LarkClientId); | ||||
|     } | ||||
|     if ( | ||||
|       originInputs['LarkClientSecret'] !== inputs.LarkClientSecret && | ||||
|       inputs.LarkClientSecret !== '' | ||||
|     ) { | ||||
|       await updateOption('LarkClientSecret', inputs.LarkClientSecret); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitTurnstile = async () => { | ||||
|     if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) { | ||||
|       await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey); | ||||
| @@ -469,6 +485,44 @@ const SystemSetting = () => { | ||||
|             保存 GitHub OAuth 设置 | ||||
|           </Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置飞书授权登录 | ||||
|             <Header.Subheader> | ||||
|               用以支持通过飞书进行登录注册, | ||||
|               <a href='https://open.feishu.cn/app' target='_blank'> | ||||
|                 点击此处 | ||||
|               </a> | ||||
|               管理你的飞书应用 | ||||
|             </Header.Subheader> | ||||
|           </Header> | ||||
|           <Message> | ||||
|             主页链接填 <code>{inputs.ServerAddress}</code> | ||||
|             ,重定向 URL 填{' '} | ||||
|             <code>{`${inputs.ServerAddress}/oauth/lark`}</code> | ||||
|           </Message> | ||||
|           <Form.Group widths={3}> | ||||
|             <Form.Input | ||||
|               label='App ID' | ||||
|               name='LarkClientId' | ||||
|               onChange={handleInputChange} | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.LarkClientId} | ||||
|               placeholder='输入 App ID' | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='App Secret' | ||||
|               name='LarkClientSecret' | ||||
|               onChange={handleInputChange} | ||||
|               type='password' | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.LarkClientSecret} | ||||
|               placeholder='敏感信息不会发送到前端显示' | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={submitLarkOAuth}> | ||||
|             保存飞书 OAuth 设置 | ||||
|           </Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 WeChat Server | ||||
|             <Header.Subheader> | ||||
|   | ||||
| @@ -17,4 +17,13 @@ export async function onGitHubOAuthClicked(github_client_id) { | ||||
|   window.open( | ||||
|     `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email` | ||||
|   ); | ||||
| } | ||||
|  | ||||
| export async function onLarkOAuthClicked(lark_client_id) { | ||||
|   const state = await getOAuthState(); | ||||
|   if (!state) return; | ||||
|   let redirect_uri = `${window.location.origin}/oauth/lark`; | ||||
|   window.open( | ||||
|     `https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}` | ||||
|   ); | ||||
| } | ||||
| @@ -17,6 +17,7 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 29, text: 'Groq', value: 29, color: 'orange' }, | ||||
|   { key: 30, text: 'Ollama', value: 30, color: 'black' }, | ||||
|   { key: 31, text: '零一万物', value: 31, color: 'green' }, | ||||
|   { key: 31, text: '阶跃星辰', value: 32, color: 'blue' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||
|   | ||||
							
								
								
									
										1
									
								
								web/default/src/images/lark.svg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								web/default/src/images/lark.svg
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| After Width: | Height: | Size: 5.4 KiB | 
| @@ -1,19 +1,22 @@ | ||||
| import React, { useEffect, useState } from 'react'; | ||||
| import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; | ||||
| import { useParams, useNavigate } from 'react-router-dom'; | ||||
| import { API, showError, showSuccess, timestamp2string } from '../../helpers'; | ||||
| import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; | ||||
| import { useNavigate, useParams } from 'react-router-dom'; | ||||
| import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers'; | ||||
| import { renderQuotaWithPrompt } from '../../helpers/render'; | ||||
|  | ||||
| const EditToken = () => { | ||||
|   const params = useParams(); | ||||
|   const tokenId = params.id; | ||||
|   const isEdit = tokenId !== undefined; | ||||
|   const [loading, setLoading] = useState(isEdit); | ||||
|   const [modelOptions, setModelOptions] = useState([]); | ||||
|   const originInputs = { | ||||
|     name: '', | ||||
|     remain_quota: isEdit ? 0 : 500000, | ||||
|     expired_time: -1, | ||||
|     unlimited_quota: false | ||||
|     unlimited_quota: false, | ||||
|     models: [], | ||||
|     subnet: "", | ||||
|   }; | ||||
|   const [inputs, setInputs] = useState(originInputs); | ||||
|   const { name, remain_quota, expired_time, unlimited_quota } = inputs; | ||||
| @@ -22,8 +25,8 @@ const EditToken = () => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
|   }; | ||||
|   const handleCancel = () => { | ||||
|     navigate("/token"); | ||||
|   } | ||||
|     navigate('/token'); | ||||
|   }; | ||||
|   const setExpiredTime = (month, day, hour, minute) => { | ||||
|     let now = new Date(); | ||||
|     let timestamp = now.getTime() / 1000; | ||||
| @@ -50,6 +53,11 @@ const EditToken = () => { | ||||
|       if (data.expired_time !== -1) { | ||||
|         data.expired_time = timestamp2string(data.expired_time); | ||||
|       } | ||||
|       if (data.models === '') { | ||||
|         data.models = []; | ||||
|       } else { | ||||
|         data.models = data.models.split(','); | ||||
|       } | ||||
|       setInputs(data); | ||||
|     } else { | ||||
|       showError(message); | ||||
| @@ -60,8 +68,26 @@ const EditToken = () => { | ||||
|     if (isEdit) { | ||||
|       loadToken().then(); | ||||
|     } | ||||
|     loadAvailableModels().then(); | ||||
|   }, []); | ||||
|  | ||||
|   const loadAvailableModels = async () => { | ||||
|     let res = await API.get(`/api/user/available_models`); | ||||
|     const { success, message, data } = res.data; | ||||
|     if (success) { | ||||
|       let options = data.map((model) => { | ||||
|         return { | ||||
|           key: model, | ||||
|           text: model, | ||||
|           value: model | ||||
|         }; | ||||
|       }); | ||||
|       setModelOptions(options); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submit = async () => { | ||||
|     if (!isEdit && inputs.name === '') return; | ||||
|     let localInputs = inputs; | ||||
| @@ -74,6 +100,7 @@ const EditToken = () => { | ||||
|       } | ||||
|       localInputs.expired_time = Math.ceil(time / 1000); | ||||
|     } | ||||
|     localInputs.models = localInputs.models.join(','); | ||||
|     let res; | ||||
|     if (isEdit) { | ||||
|       res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); | ||||
| @@ -109,6 +136,34 @@ const EditToken = () => { | ||||
|               required={!isEdit} | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Dropdown | ||||
|               label='模型范围' | ||||
|               placeholder={'请选择允许使用的模型,留空则不进行限制'} | ||||
|               name='models' | ||||
|               fluid | ||||
|               multiple | ||||
|               search | ||||
|               onLabelClick={(e, { value }) => { | ||||
|                 copy(value).then(); | ||||
|               }} | ||||
|               selection | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.models} | ||||
|               autoComplete='new-password' | ||||
|               options={modelOptions} | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='IP 限制' | ||||
|               name='subnet' | ||||
|               placeholder={'请输入允许访问的网段,例如:192.168.0.0/24'} | ||||
|               onChange={handleInputChange} | ||||
|               value={inputs.subnet} | ||||
|               autoComplete='new-password' | ||||
|             /> | ||||
|           </Form.Field> | ||||
|           <Form.Field> | ||||
|             <Form.Input | ||||
|               label='过期时间' | ||||
|   | ||||
| @@ -8,6 +8,7 @@ const TopUp = () => { | ||||
|   const [topUpLink, setTopUpLink] = useState(''); | ||||
|   const [userQuota, setUserQuota] = useState(0); | ||||
|   const [isSubmitting, setIsSubmitting] = useState(false); | ||||
|   const [user, setUser] = useState({}); | ||||
|  | ||||
|   const topUp = async () => { | ||||
|     if (redemptionCode === '') { | ||||
| @@ -41,7 +42,14 @@ const TopUp = () => { | ||||
|       showError('超级管理员未设置充值链接!'); | ||||
|       return; | ||||
|     } | ||||
|     window.open(topUpLink, '_blank'); | ||||
|     let url = new URL(topUpLink); | ||||
|     let username = user.username; | ||||
|     let user_id = user.id; | ||||
|     // add  username and user_id to the topup link | ||||
|     url.searchParams.append('username', username); | ||||
|     url.searchParams.append('user_id', user_id); | ||||
|     url.searchParams.append('transaction_id', crypto.randomUUID()); | ||||
|     window.open(url.toString(), '_blank'); | ||||
|   }; | ||||
|  | ||||
|   const getUserQuota = async ()=>{ | ||||
| @@ -49,6 +57,7 @@ const TopUp = () => { | ||||
|     const {success, message, data} = res.data; | ||||
|     if (success) { | ||||
|       setUserQuota(data.quota); | ||||
|       setUser(data); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -80,7 +89,7 @@ const TopUp = () => { | ||||
|               }} | ||||
|             /> | ||||
|             <Button color='green' onClick={openTopUpLink}> | ||||
|               获取兑换码 | ||||
|               充值 | ||||
|             </Button> | ||||
|             <Button color='yellow' onClick={topUp} disabled={isSubmitting}> | ||||
|                 {isSubmitting ? '兑换中...' : '兑换'} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user