mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			15 Commits
		
	
	
		
			v0.6.5-alp
			...
			v0.6.5-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					0bb7db0b44 | ||
| 
						 | 
					4d61b9937b | ||
| 
						 | 
					68605800af | ||
| 
						 | 
					c49778c254 | ||
| 
						 | 
					f02c7138ea | ||
| 
						 | 
					ca3228855a | ||
| 
						 | 
					f8cc63f00b | ||
| 
						 | 
					0a37aa4cbd | ||
| 
						 | 
					054b00b725 | ||
| 
						 | 
					76569bb0b6 | ||
| 
						 | 
					1994256bac | ||
| 
						 | 
					1f80b0a39f | ||
| 
						 | 
					f73f2e51df | ||
| 
						 | 
					6f036bd0c9 | ||
| 
						 | 
					fb90747c23 | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -8,3 +8,4 @@ build
 | 
			
		||||
logs
 | 
			
		||||
data
 | 
			
		||||
/web/node_modules
 | 
			
		||||
cmd.md
 | 
			
		||||
@@ -85,7 +85,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
3. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
5. 支持**多机部署**,[详见此处](#多机部署)。
 | 
			
		||||
6. 支持**令牌管理**,设置令牌的过期时间和额度。
 | 
			
		||||
6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。
 | 
			
		||||
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
 | 
			
		||||
8. 支持**渠道管理**,批量创建渠道。
 | 
			
		||||
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
 | 
			
		||||
@@ -101,10 +101,11 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
19. 支持丰富的**自定义**设置,
 | 
			
		||||
    1. 支持自定义系统名称,logo 以及页脚。
 | 
			
		||||
    2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
			
		||||
20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
 | 
			
		||||
20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。
 | 
			
		||||
21. 支持 Cloudflare Turnstile 用户校验。
 | 
			
		||||
22. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
 | 
			
		||||
    + 支持使用飞书进行授权登录。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
 | 
			
		||||
 
 | 
			
		||||
@@ -66,6 +66,9 @@ var SMTPToken = ""
 | 
			
		||||
var GitHubClientId = ""
 | 
			
		||||
var GitHubClientSecret = ""
 | 
			
		||||
 | 
			
		||||
var LarkClientId = ""
 | 
			
		||||
var LarkClientSecret = ""
 | 
			
		||||
 | 
			
		||||
var WeChatServerAddress = ""
 | 
			
		||||
var WeChatServerToken = ""
 | 
			
		||||
var WeChatAccountQRCodeImageURL = ""
 | 
			
		||||
 
 | 
			
		||||
@@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-opus-20240229":   15.0 / 1000 * USD,
 | 
			
		||||
	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
 | 
			
		||||
	"ERNIE-Bot":       0.8572,     // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo": 0.5715,     // ¥0.008 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-4":     0.12 * RMB, // ¥0.12 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-8K":    0.024 * RMB,
 | 
			
		||||
	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"bge-large-zh":    0.002 * RMB,
 | 
			
		||||
	"bge-large-en":    0.002 * RMB,
 | 
			
		||||
	"bge-large-8k":    0.002 * RMB,
 | 
			
		||||
	"ERNIE-4.0-8K":       0.120 * RMB,
 | 
			
		||||
	"ERNIE-Bot-8K-0922":  0.024 * RMB,
 | 
			
		||||
	"ERNIE-3.5-8K":       0.012 * RMB,
 | 
			
		||||
	"ERNIE-Lite-8K-0922": 0.008 * RMB,
 | 
			
		||||
	"ERNIE-Speed-8K":     0.004 * RMB,
 | 
			
		||||
	"ERNIE-3.5-4K-0205":  0.012 * RMB,
 | 
			
		||||
	"ERNIE-3.5-8K-0205":  0.024 * RMB,
 | 
			
		||||
	"ERNIE-3.5-8K-1222":  0.012 * RMB,
 | 
			
		||||
	"ERNIE-Lite-8K":      0.003 * RMB,
 | 
			
		||||
	"ERNIE-Speed-128K":   0.004 * RMB,
 | 
			
		||||
	"ERNIE-Tiny-8K":      0.001 * RMB,
 | 
			
		||||
	"BLOOMZ-7B":          0.004 * RMB,
 | 
			
		||||
	"Embedding-V1":       0.002 * RMB,
 | 
			
		||||
	"bge-large-zh":       0.002 * RMB,
 | 
			
		||||
	"bge-large-en":       0.002 * RMB,
 | 
			
		||||
	"tao-8k":             0.002 * RMB,
 | 
			
		||||
	// https://ai.google.dev/pricing
 | 
			
		||||
	"PaLM-2":                    1,
 | 
			
		||||
	"gemini-pro":                1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 | 
			
		||||
@@ -91,6 +99,7 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"glm-4":                     0.1 * RMB,
 | 
			
		||||
	"glm-4v":                    0.1 * RMB,
 | 
			
		||||
	"glm-3-turbo":               0.005 * RMB,
 | 
			
		||||
	"embedding-2":               0.0005 * RMB,
 | 
			
		||||
	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										22
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								common/network/ip.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
package network
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func IsValidSubnet(subnet string) error {
 | 
			
		||||
	_, _, err := net.ParseCIDR(subnet)
 | 
			
		||||
	return fmt.Errorf("failed to parse subnet: %w", err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
 | 
			
		||||
	_, ipNet, err := net.ParseCIDR(subnet)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return ipNet.Contains(net.ParseIP(ip))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								common/network/ip_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package network
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	. "github.com/smartystreets/goconvey/convey"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestIsIpInSubnet(t *testing.T) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	ip1 := "192.168.0.5"
 | 
			
		||||
	ip2 := "125.216.250.89"
 | 
			
		||||
	subnet := "192.168.0.0/24"
 | 
			
		||||
	Convey("TestIsIpInSubnet", t, func() {
 | 
			
		||||
		So(IsIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
 | 
			
		||||
		So(IsIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package controller
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
@@ -11,6 +11,7 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -159,7 +160,7 @@ func GitHubOAuth(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	setupLogin(&user, c)
 | 
			
		||||
	controller.SetupLogin(&user, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GitHubBind(c *gin.Context) {
 | 
			
		||||
							
								
								
									
										201
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								controller/auth/lark.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,201 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type LarkOAuthResponse struct {
 | 
			
		||||
	AccessToken string `json:"access_token"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type LarkUser struct {
 | 
			
		||||
	Name   string `json:"name"`
 | 
			
		||||
	OpenID string `json:"open_id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getLarkUserInfoByCode(code string) (*LarkUser, error) {
 | 
			
		||||
	if code == "" {
 | 
			
		||||
		return nil, errors.New("无效的参数")
 | 
			
		||||
	}
 | 
			
		||||
	values := map[string]string{
 | 
			
		||||
		"client_id":     config.LarkClientId,
 | 
			
		||||
		"client_secret": config.LarkClientSecret,
 | 
			
		||||
		"code":          code,
 | 
			
		||||
		"grant_type":    "authorization_code",
 | 
			
		||||
		"redirect_uri":  fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
 | 
			
		||||
	}
 | 
			
		||||
	jsonData, err := json.Marshal(values)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	req.Header.Set("Accept", "application/json")
 | 
			
		||||
	client := http.Client{
 | 
			
		||||
		Timeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	res, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysLog(err.Error())
 | 
			
		||||
		return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
	var oAuthResponse LarkOAuthResponse
 | 
			
		||||
	err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
 | 
			
		||||
	res2, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysLog(err.Error())
 | 
			
		||||
		return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
 | 
			
		||||
	}
 | 
			
		||||
	var larkUser LarkUser
 | 
			
		||||
	err = json.NewDecoder(res2.Body).Decode(&larkUser)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &larkUser, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LarkOAuth(c *gin.Context) {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	state := c.Query("state")
 | 
			
		||||
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
 | 
			
		||||
		c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "state is empty or not same",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	username := session.Get("username")
 | 
			
		||||
	if username != nil {
 | 
			
		||||
		LarkBind(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	code := c.Query("code")
 | 
			
		||||
	larkUser, err := getLarkUserInfoByCode(code)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user := model.User{
 | 
			
		||||
		LarkId: larkUser.OpenID,
 | 
			
		||||
	}
 | 
			
		||||
	if model.IsLarkIdAlreadyTaken(user.LarkId) {
 | 
			
		||||
		err := user.FillUserByLarkId()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": err.Error(),
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if config.RegisterEnabled {
 | 
			
		||||
			user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
 | 
			
		||||
			if larkUser.Name != "" {
 | 
			
		||||
				user.DisplayName = larkUser.Name
 | 
			
		||||
			} else {
 | 
			
		||||
				user.DisplayName = "Lark User"
 | 
			
		||||
			}
 | 
			
		||||
			user.Role = common.RoleCommonUser
 | 
			
		||||
			user.Status = common.UserStatusEnabled
 | 
			
		||||
 | 
			
		||||
			if err := user.Insert(0); err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
					"success": false,
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "管理员关闭了新用户注册",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Status != common.UserStatusEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"message": "用户已被封禁",
 | 
			
		||||
			"success": false,
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	controller.SetupLogin(&user, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LarkBind(c *gin.Context) {
 | 
			
		||||
	code := c.Query("code")
 | 
			
		||||
	larkUser, err := getLarkUserInfoByCode(code)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user := model.User{
 | 
			
		||||
		LarkId: larkUser.OpenID,
 | 
			
		||||
	}
 | 
			
		||||
	if model.IsLarkIdAlreadyTaken(user.LarkId) {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "该飞书账户已被绑定",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	id := session.Get("id")
 | 
			
		||||
	// id := c.GetInt("id")  // critical bug!
 | 
			
		||||
	user.Id = id.(int)
 | 
			
		||||
	err = user.FillUserById()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user.LarkId = larkUser.OpenID
 | 
			
		||||
	err = user.Update(false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "bind",
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package controller
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -109,7 +110,7 @@ func WeChatAuth(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	setupLogin(&user, c)
 | 
			
		||||
	controller.SetupLogin(&user, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WeChatBind(c *gin.Context) {
 | 
			
		||||
@@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) {
 | 
			
		||||
			"email_verification":  config.EmailVerificationEnabled,
 | 
			
		||||
			"github_oauth":        config.GitHubOAuthEnabled,
 | 
			
		||||
			"github_client_id":    config.GitHubClientId,
 | 
			
		||||
			"lark_client_id":      config.LarkClientId,
 | 
			
		||||
			"system_name":         config.SystemName,
 | 
			
		||||
			"logo":                config.Logo,
 | 
			
		||||
			"footer_html":         config.Footer,
 | 
			
		||||
 
 | 
			
		||||
@@ -135,7 +135,7 @@ func ListModels(c *gin.Context) {
 | 
			
		||||
	for _, availableModel := range availableModels {
 | 
			
		||||
		modelSet[availableModel] = true
 | 
			
		||||
	}
 | 
			
		||||
	var availableOpenAIModels []OpenAIModels
 | 
			
		||||
	availableOpenAIModels := make([]OpenAIModels, 0)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
		if _, ok := modelSet[model.Id]; ok {
 | 
			
		||||
			modelSet[model.Id] = false
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,12 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/network"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -104,6 +106,19 @@ func GetTokenStatus(c *gin.Context) {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateToken(c *gin.Context, token model.Token) error {
 | 
			
		||||
	if len(token.Name) > 30 {
 | 
			
		||||
		return fmt.Errorf("令牌名称过长")
 | 
			
		||||
	}
 | 
			
		||||
	if token.Subnet != nil && *token.Subnet != "" {
 | 
			
		||||
		err := network.IsValidSubnet(*token.Subnet)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("无效的网段:%s", err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AddToken(c *gin.Context) {
 | 
			
		||||
	token := model.Token{}
 | 
			
		||||
	err := c.ShouldBindJSON(&token)
 | 
			
		||||
@@ -114,13 +129,15 @@ func AddToken(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if len(token.Name) > 30 {
 | 
			
		||||
	err = validateToken(c, token)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "令牌名称过长",
 | 
			
		||||
			"message": fmt.Sprintf("参数错误:%s", err.Error()),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cleanToken := model.Token{
 | 
			
		||||
		UserId:         c.GetInt("id"),
 | 
			
		||||
		Name:           token.Name,
 | 
			
		||||
@@ -131,6 +148,7 @@ func AddToken(c *gin.Context) {
 | 
			
		||||
		RemainQuota:    token.RemainQuota,
 | 
			
		||||
		UnlimitedQuota: token.UnlimitedQuota,
 | 
			
		||||
		Models:         token.Models,
 | 
			
		||||
		Subnet:         token.Subnet,
 | 
			
		||||
	}
 | 
			
		||||
	err = cleanToken.Insert()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -178,10 +196,11 @@ func UpdateToken(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if len(token.Name) > 30 {
 | 
			
		||||
	err = validateToken(c, token)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "令牌名称过长",
 | 
			
		||||
			"message": fmt.Sprintf("参数错误:%s", err.Error()),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -218,6 +237,7 @@ func UpdateToken(c *gin.Context) {
 | 
			
		||||
		cleanToken.RemainQuota = token.RemainQuota
 | 
			
		||||
		cleanToken.UnlimitedQuota = token.UnlimitedQuota
 | 
			
		||||
		cleanToken.Models = token.Models
 | 
			
		||||
		cleanToken.Subnet = token.Subnet
 | 
			
		||||
	}
 | 
			
		||||
	err = cleanToken.Update()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -58,11 +58,11 @@ func Login(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	setupLogin(&user, c)
 | 
			
		||||
	SetupLogin(&user, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// setup session & cookies and then return user info
 | 
			
		||||
func setupLogin(user *model.User, c *gin.Context) {
 | 
			
		||||
func SetupLogin(user *model.User, c *gin.Context) {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	session.Set("id", user.Id)
 | 
			
		||||
	session.Set("username", user.Username)
 | 
			
		||||
@@ -180,27 +180,27 @@ func Register(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllUsers(c *gin.Context) {
 | 
			
		||||
    p, _ := strconv.Atoi(c.Query("p"))
 | 
			
		||||
    if p < 0 {
 | 
			
		||||
        p = 0
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    order := c.DefaultQuery("order", "")
 | 
			
		||||
    users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
 | 
			
		||||
	
 | 
			
		||||
    if err != nil {
 | 
			
		||||
        c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
            "success": false,
 | 
			
		||||
            "message": err.Error(),
 | 
			
		||||
        })
 | 
			
		||||
        return
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
        "success": true,
 | 
			
		||||
        "message": "",
 | 
			
		||||
        "data":    users,
 | 
			
		||||
    })
 | 
			
		||||
	p, _ := strconv.Atoi(c.Query("p"))
 | 
			
		||||
	if p < 0 {
 | 
			
		||||
		p = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	order := c.DefaultQuery("order", "")
 | 
			
		||||
	users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    users,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchUsers(c *gin.Context) {
 | 
			
		||||
@@ -770,3 +770,38 @@ func TopUp(c *gin.Context) {
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type adminTopUpRequest struct {
 | 
			
		||||
	UserId int    `json:"user_id"`
 | 
			
		||||
	Quota  int    `json:"quota"`
 | 
			
		||||
	Remark string `json:"remark"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AdminTopUp(c *gin.Context) {
 | 
			
		||||
	req := adminTopUpRequest{}
 | 
			
		||||
	err := c.ShouldBindJSON(&req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if req.Remark == "" {
 | 
			
		||||
		req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
 | 
			
		||||
	}
 | 
			
		||||
	model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								docs/API.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
# 使用 API 操控 & 扩展 One API
 | 
			
		||||
> 欢迎提交 PR 在此放上你的拓展项目。
 | 
			
		||||
 | 
			
		||||
例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。
 | 
			
		||||
 | 
			
		||||
又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。
 | 
			
		||||
 | 
			
		||||
## 鉴权
 | 
			
		||||
One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取:
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API:
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
## 请求格式与响应格式
 | 
			
		||||
One API 使用 JSON 格式进行请求和响应。
 | 
			
		||||
 | 
			
		||||
对于响应体,一般格式如下:
 | 
			
		||||
```json
 | 
			
		||||
{
 | 
			
		||||
  "message": "请求信息",
 | 
			
		||||
  "success": true,
 | 
			
		||||
  "data": {}
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## API 列表
 | 
			
		||||
> 当前 API 列表不全,请自行通过浏览器抓取前端请求
 | 
			
		||||
 | 
			
		||||
如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。
 | 
			
		||||
 | 
			
		||||
### 获取当前登录用户信息
 | 
			
		||||
**GET** `/api/user/self`
 | 
			
		||||
 | 
			
		||||
### 为给定用户充值额度
 | 
			
		||||
**POST** `/api/topup`
 | 
			
		||||
```json
 | 
			
		||||
{
 | 
			
		||||
  "user_id": 1,
 | 
			
		||||
  "quota": 100000,
 | 
			
		||||
  "remark": "充值 100000 额度"
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## 其他
 | 
			
		||||
### 充值链接上的附加参数
 | 
			
		||||
One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如:
 | 
			
		||||
`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837`
 | 
			
		||||
 | 
			
		||||
你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。
 | 
			
		||||
 | 
			
		||||
注意,不是所有主题都支持该功能,欢迎 PR 补齐。
 | 
			
		||||
							
								
								
									
										4
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
									
									
									
									
								
							@@ -15,6 +15,7 @@ require (
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.0
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.5
 | 
			
		||||
	github.com/smartystreets/goconvey v1.8.1
 | 
			
		||||
	github.com/stretchr/testify v1.8.3
 | 
			
		||||
	golang.org/x/crypto v0.17.0
 | 
			
		||||
	golang.org/x/image v0.14.0
 | 
			
		||||
@@ -37,6 +38,7 @@ require (
 | 
			
		||||
	github.com/go-playground/universal-translator v0.18.1 // indirect
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.6.0 // indirect
 | 
			
		||||
	github.com/goccy/go-json v0.10.2 // indirect
 | 
			
		||||
	github.com/gopherjs/gopherjs v1.17.2 // indirect
 | 
			
		||||
	github.com/gorilla/context v1.1.1 // indirect
 | 
			
		||||
	github.com/gorilla/securecookie v1.1.1 // indirect
 | 
			
		||||
	github.com/gorilla/sessions v1.2.1 // indirect
 | 
			
		||||
@@ -47,6 +49,7 @@ require (
 | 
			
		||||
	github.com/jinzhu/inflection v1.0.0 // indirect
 | 
			
		||||
	github.com/jinzhu/now v1.1.5 // indirect
 | 
			
		||||
	github.com/json-iterator/go v1.1.12 // indirect
 | 
			
		||||
	github.com/jtolds/gls v4.20.0+incompatible // indirect
 | 
			
		||||
	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 | 
			
		||||
	github.com/leodido/go-urn v1.2.4 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.19 // indirect
 | 
			
		||||
@@ -55,6 +58,7 @@ require (
 | 
			
		||||
	github.com/modern-go/reflect2 v1.0.2 // indirect
 | 
			
		||||
	github.com/pelletier/go-toml/v2 v2.0.8 // indirect
 | 
			
		||||
	github.com/pmezard/go-difflib v1.0.0 // indirect
 | 
			
		||||
	github.com/smarty/assertions v1.15.0 // indirect
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	github.com/ugorji/go/codec v1.2.11 // indirect
 | 
			
		||||
	golang.org/x/arch v0.3.0 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										12
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								go.sum
									
									
									
									
									
								
							@@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
 | 
			
		||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
 | 
			
		||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
 | 
			
		||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 | 
			
		||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
 | 
			
		||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 | 
			
		||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 | 
			
		||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 | 
			
		||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
 | 
			
		||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
 | 
			
		||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
 | 
			
		||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
 | 
			
		||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
 | 
			
		||||
@@ -85,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
 | 
			
		||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
 | 
			
		||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 | 
			
		||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
 | 
			
		||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
 | 
			
		||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
 | 
			
		||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
 | 
			
		||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
 | 
			
		||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
 | 
			
		||||
@@ -127,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
 | 
			
		||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
 | 
			
		||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
 | 
			
		||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
 | 
			
		||||
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
 | 
			
		||||
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
 | 
			
		||||
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
 | 
			
		||||
github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
 | 
			
		||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
			
		||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 | 
			
		||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 | 
			
		||||
@@ -177,8 +185,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 | 
			
		||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
 | 
			
		||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
 | 
			
		||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
 | 
			
		||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 | 
			
		||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 | 
			
		||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/blacklist"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/network"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -89,6 +90,7 @@ func RootAuth() func(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		ctx := c.Request.Context()
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
		key = strings.TrimPrefix(key, "Bearer ")
 | 
			
		||||
		key = strings.TrimPrefix(key, "sk-")
 | 
			
		||||
@@ -99,6 +101,12 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			abortWithMessage(c, http.StatusUnauthorized, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if token.Subnet != nil && *token.Subnet != "" {
 | 
			
		||||
			if !network.IsIpInSubnet(ctx, c.ClientIP(), *token.Subnet) {
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		userEnabled, err := model.CacheIsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										15
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								model/log.go
									
									
									
									
									
								
							@@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordTopupLog(userId int, content string, quota int) {
 | 
			
		||||
	log := &Log{
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		Username:  GetUsernameById(userId),
 | 
			
		||||
		CreatedAt: helper.GetTimestamp(),
 | 
			
		||||
		Type:      LogTypeTopup,
 | 
			
		||||
		Content:   content,
 | 
			
		||||
		Quota:     quota,
 | 
			
		||||
	}
 | 
			
		||||
	err := LOG_DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("failed to record log: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
 | 
			
		||||
	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 | 
			
		||||
	if !config.LogConsumeEnabled {
 | 
			
		||||
 
 | 
			
		||||
@@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
		config.GitHubClientId = value
 | 
			
		||||
	case "GitHubClientSecret":
 | 
			
		||||
		config.GitHubClientSecret = value
 | 
			
		||||
	case "LarkClientId":
 | 
			
		||||
		config.LarkClientId = value
 | 
			
		||||
	case "LarkClientSecret":
 | 
			
		||||
		config.LarkClientSecret = value
 | 
			
		||||
	case "Footer":
 | 
			
		||||
		config.Footer = value
 | 
			
		||||
	case "SystemName":
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,8 @@ type Token struct {
 | 
			
		||||
	RemainQuota    int64   `json:"remain_quota" gorm:"bigint;default:0"`
 | 
			
		||||
	UnlimitedQuota bool    `json:"unlimited_quota" gorm:"default:false"`
 | 
			
		||||
	UsedQuota      int64   `json:"used_quota" gorm:"bigint;default:0"` // used quota
 | 
			
		||||
	Models         *string `json:"models" gorm:"default:''"`
 | 
			
		||||
	Models         *string `json:"models" gorm:"default:''"`           // allowed models
 | 
			
		||||
	Subnet         *string `json:"subnet" gorm:"default:''"`           // allowed subnet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
 | 
			
		||||
@@ -122,7 +123,7 @@ func (token *Token) Insert() error {
 | 
			
		||||
// Update Make sure your token's fields is completed, because this will update non-zero values
 | 
			
		||||
func (token *Token) Update() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error
 | 
			
		||||
	err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -24,6 +24,7 @@ type User struct {
 | 
			
		||||
	Email            string `json:"email" gorm:"index" validate:"max=50"`
 | 
			
		||||
	GitHubId         string `json:"github_id" gorm:"column:github_id;index"`
 | 
			
		||||
	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"`
 | 
			
		||||
	LarkId           string `json:"lark_id" gorm:"column:lark_id;index"`
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
 | 
			
		||||
	Quota            int64  `json:"quota" gorm:"bigint;default:0"`
 | 
			
		||||
@@ -41,21 +42,21 @@ func GetMaxUserId() int {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
 | 
			
		||||
    query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
 | 
			
		||||
    
 | 
			
		||||
    switch order {
 | 
			
		||||
    case "quota":
 | 
			
		||||
        query = query.Order("quota desc")
 | 
			
		||||
    case "used_quota":
 | 
			
		||||
        query = query.Order("used_quota desc")
 | 
			
		||||
    case "request_count":
 | 
			
		||||
        query = query.Order("request_count desc")
 | 
			
		||||
    default:
 | 
			
		||||
        query = query.Order("id desc")
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    err = query.Find(&users).Error
 | 
			
		||||
    return users, err
 | 
			
		||||
	query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
 | 
			
		||||
 | 
			
		||||
	switch order {
 | 
			
		||||
	case "quota":
 | 
			
		||||
		query = query.Order("quota desc")
 | 
			
		||||
	case "used_quota":
 | 
			
		||||
		query = query.Order("used_quota desc")
 | 
			
		||||
	case "request_count":
 | 
			
		||||
		query = query.Order("request_count desc")
 | 
			
		||||
	default:
 | 
			
		||||
		query = query.Order("id desc")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = query.Find(&users).Error
 | 
			
		||||
	return users, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchUsers(keyword string) (users []*User, err error) {
 | 
			
		||||
@@ -206,6 +207,14 @@ func (user *User) FillUserByGitHubId() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (user *User) FillUserByLarkId() error {
 | 
			
		||||
	if user.LarkId == "" {
 | 
			
		||||
		return errors.New("lark id 为空!")
 | 
			
		||||
	}
 | 
			
		||||
	DB.Where(User{LarkId: user.LarkId}).First(user)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (user *User) FillUserByWeChatId() error {
 | 
			
		||||
	if user.WeChatId == "" {
 | 
			
		||||
		return errors.New("WeChat id 为空!")
 | 
			
		||||
@@ -234,6 +243,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
 | 
			
		||||
	return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsLarkIdAlreadyTaken(githubId string) bool {
 | 
			
		||||
	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsUsernameAlreadyTaken(username string) bool {
 | 
			
		||||
	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -38,16 +38,26 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
		suffix += "completions_pro"
 | 
			
		||||
	case "ERNIE-Bot-4":
 | 
			
		||||
		suffix += "completions_pro"
 | 
			
		||||
	case "ERNIE-3.5-8K":
 | 
			
		||||
		suffix += "completions"
 | 
			
		||||
	case "ERNIE-Bot-8K":
 | 
			
		||||
		suffix += "ernie_bot_8k"
 | 
			
		||||
	case "ERNIE-Bot":
 | 
			
		||||
		suffix += "completions"
 | 
			
		||||
	case "ERNIE-Speed":
 | 
			
		||||
		suffix += "ernie_speed"
 | 
			
		||||
	case "ERNIE-Bot-turbo":
 | 
			
		||||
		suffix += "eb-instant"
 | 
			
		||||
	case "ERNIE-Speed":
 | 
			
		||||
		suffix += "ernie_speed"
 | 
			
		||||
	case "ERNIE-Bot-8K":
 | 
			
		||||
		suffix += "ernie_bot_8k"
 | 
			
		||||
	case "ERNIE-4.0-8K":
 | 
			
		||||
		suffix += "completions_pro"
 | 
			
		||||
	case "ERNIE-3.5-8K":
 | 
			
		||||
		suffix += "completions"
 | 
			
		||||
	case "ERNIE-Speed-8K":
 | 
			
		||||
		suffix += "ernie_speed"
 | 
			
		||||
	case "ERNIE-Speed-128K":
 | 
			
		||||
		suffix += "ernie-speed-128k"
 | 
			
		||||
	case "ERNIE-Lite-8K":
 | 
			
		||||
		suffix += "ernie-lite-8k"
 | 
			
		||||
	case "ERNIE-Tiny-8K":
 | 
			
		||||
		suffix += "ernie-tiny-8k"
 | 
			
		||||
	case "BLOOMZ-7B":
 | 
			
		||||
		suffix += "bloomz_7b1"
 | 
			
		||||
	case "Embedding-V1":
 | 
			
		||||
@@ -59,7 +69,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
	case "tao-8k":
 | 
			
		||||
		suffix += "tao_8k"
 | 
			
		||||
	default:
 | 
			
		||||
		suffix += meta.ActualModelName
 | 
			
		||||
		suffix += strings.ToLower(meta.ActualModelName)
 | 
			
		||||
	}
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
 | 
			
		||||
	var accessToken string
 | 
			
		||||
 
 | 
			
		||||
@@ -1,11 +1,18 @@
 | 
			
		||||
package baidu
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"ERNIE-Bot-4",
 | 
			
		||||
	"ERNIE-Bot-8K",
 | 
			
		||||
	"ERNIE-Bot",
 | 
			
		||||
	"ERNIE-Speed",
 | 
			
		||||
	"ERNIE-Bot-turbo",
 | 
			
		||||
	"ERNIE-4.0-8K",
 | 
			
		||||
	"ERNIE-Bot-8K-0922",
 | 
			
		||||
	"ERNIE-3.5-8K",
 | 
			
		||||
	"ERNIE-Lite-8K-0922",
 | 
			
		||||
	"ERNIE-Speed-8K",
 | 
			
		||||
	"ERNIE-3.5-4K-0205",
 | 
			
		||||
	"ERNIE-3.5-8K-0205",
 | 
			
		||||
	"ERNIE-3.5-8K-1222",
 | 
			
		||||
	"ERNIE-Lite-8K",
 | 
			
		||||
	"ERNIE-Speed-128K",
 | 
			
		||||
	"ERNIE-Tiny-8K",
 | 
			
		||||
	"BLOOMZ-7B",
 | 
			
		||||
	"Embedding-V1",
 | 
			
		||||
	"bge-large-zh",
 | 
			
		||||
	"bge-large-en",
 | 
			
		||||
 
 | 
			
		||||
@@ -28,7 +28,7 @@ type ChatRequest struct {
 | 
			
		||||
		} `json:"message"`
 | 
			
		||||
		Functions struct {
 | 
			
		||||
			Text []model.Function `json:"text,omitempty"`
 | 
			
		||||
		} `json:"functions"`
 | 
			
		||||
		} `json:"functions,omitempty"`
 | 
			
		||||
	} `json:"payload"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/util"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
	if a.APIVersion == "v4" {
 | 
			
		||||
		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
 | 
			
		||||
	}
 | 
			
		||||
	if meta.Mode == constant.RelayModeEmbeddings {
 | 
			
		||||
		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
 | 
			
		||||
	}
 | 
			
		||||
	method := "invoke"
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		method = "sse-invoke"
 | 
			
		||||
@@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
	// TopP (0.0, 1.0)
 | 
			
		||||
	request.TopP = math.Min(0.99, request.TopP)
 | 
			
		||||
	request.TopP = math.Max(0.01, request.TopP)
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case constant.RelayModeEmbeddings:
 | 
			
		||||
		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
 | 
			
		||||
		return baiduEmbeddingRequest, nil
 | 
			
		||||
	default:
 | 
			
		||||
		// TopP (0.0, 1.0)
 | 
			
		||||
		request.TopP = math.Min(0.99, request.TopP)
 | 
			
		||||
		request.TopP = math.Max(0.01, request.TopP)
 | 
			
		||||
 | 
			
		||||
	// Temperature (0.0, 1.0)
 | 
			
		||||
	request.Temperature = math.Min(0.99, request.Temperature)
 | 
			
		||||
	request.Temperature = math.Max(0.01, request.Temperature)
 | 
			
		||||
	a.SetVersionByModeName(request.Model)
 | 
			
		||||
	if a.APIVersion == "v4" {
 | 
			
		||||
		return request, nil
 | 
			
		||||
		// Temperature (0.0, 1.0)
 | 
			
		||||
		request.Temperature = math.Min(0.99, request.Temperature)
 | 
			
		||||
		request.Temperature = math.Max(0.01, request.Temperature)
 | 
			
		||||
		a.SetVersionByModeName(request.Model)
 | 
			
		||||
		if a.APIVersion == "v4" {
 | 
			
		||||
			return request, nil
 | 
			
		||||
		}
 | 
			
		||||
		return ConvertRequest(*request), nil
 | 
			
		||||
	}
 | 
			
		||||
	return ConvertRequest(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
@@ -84,14 +94,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
 | 
			
		||||
	if a.APIVersion == "v4" {
 | 
			
		||||
		return a.DoResponseV4(c, resp, meta)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		err, usage = StreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = Handler(c, resp)
 | 
			
		||||
		if meta.Mode == constant.RelayModeEmbeddings {
 | 
			
		||||
			err, usage = EmbeddingsHandler(c, resp)
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage = Handler(c, resp)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 | 
			
		||||
	return &EmbeddingRequest{
 | 
			
		||||
		Model: "embedding-2",
 | 
			
		||||
		Input: request.Input.(string),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetModelList() []string {
 | 
			
		||||
	return ModelList
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,5 +2,5 @@ package zhipu
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
 | 
			
		||||
	"glm-4", "glm-4v", "glm-3-turbo",
 | 
			
		||||
	"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var zhipuResponse EmbeddingRespone
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &zhipuResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
 | 
			
		||||
	openAIEmbeddingResponse := openai.EmbeddingResponse{
 | 
			
		||||
		Object: "list",
 | 
			
		||||
		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
 | 
			
		||||
		Model:  response.Model,
 | 
			
		||||
		Usage: model.Usage{
 | 
			
		||||
			PromptTokens:     response.PromptTokens,
 | 
			
		||||
			CompletionTokens: response.CompletionTokens,
 | 
			
		||||
			TotalTokens:      response.Usage.TotalTokens,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, item := range response.Embeddings {
 | 
			
		||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
 | 
			
		||||
			Object:    `embedding`,
 | 
			
		||||
			Index:     item.Index,
 | 
			
		||||
			Embedding: item.Embedding,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &openAIEmbeddingResponse
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -44,3 +44,21 @@ type tokenData struct {
 | 
			
		||||
	Token      string
 | 
			
		||||
	ExpiryTime time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmbeddingRequest struct {
 | 
			
		||||
	Model string `json:"model"`
 | 
			
		||||
	Input string `json:"input"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmbeddingRespone struct {
 | 
			
		||||
	Model       string          `json:"model"`
 | 
			
		||||
	Object      string          `json:"object"`
 | 
			
		||||
	Embeddings  []EmbeddingData `json:"data"`
 | 
			
		||||
	model.Usage `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmbeddingData struct {
 | 
			
		||||
	Index     int       `json:"index"`
 | 
			
		||||
	Object    string    `json:"object"`
 | 
			
		||||
	Embedding []float64 `json:"embedding"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
 | 
			
		||||
	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	//if strings.Contains(err.Message, "quota") {
 | 
			
		||||
	//	return true
 | 
			
		||||
	//}
 | 
			
		||||
	if strings.Contains(err.Message, "credit") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(err.Message, "balance") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package router
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller/auth"
 | 
			
		||||
	"github.com/songquanpeng/one-api/middleware"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/gzip"
 | 
			
		||||
@@ -21,11 +22,13 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
		apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
 | 
			
		||||
		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
 | 
			
		||||
		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
 | 
			
		||||
		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
 | 
			
		||||
		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
 | 
			
		||||
		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
 | 
			
		||||
		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
 | 
			
		||||
		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
 | 
			
		||||
		apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
 | 
			
		||||
		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
 | 
			
		||||
		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
 | 
			
		||||
		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind)
 | 
			
		||||
		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
 | 
			
		||||
		apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp)
 | 
			
		||||
 | 
			
		||||
		userRoute := apiRouter.Group("/user")
 | 
			
		||||
		{
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,9 @@
 | 
			
		||||
 | 
			
		||||
> 每个文件夹代表一个主题,欢迎提交你的主题
 | 
			
		||||
 | 
			
		||||
> [!WARNING]
 | 
			
		||||
> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR
 | 
			
		||||
 | 
			
		||||
## 提交新的主题
 | 
			
		||||
 | 
			
		||||
> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接
 | 
			
		||||
 
 | 
			
		||||
@@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption';
 | 
			
		||||
import TopUp from './pages/TopUp';
 | 
			
		||||
import Log from './pages/Log';
 | 
			
		||||
import Chat from './pages/Chat';
 | 
			
		||||
import LarkOAuth from './components/LarkOAuth';
 | 
			
		||||
 | 
			
		||||
const Home = lazy(() => import('./pages/Home'));
 | 
			
		||||
const About = lazy(() => import('./pages/About'));
 | 
			
		||||
@@ -239,6 +240,14 @@ function App() {
 | 
			
		||||
          </Suspense>
 | 
			
		||||
        }
 | 
			
		||||
      />
 | 
			
		||||
      <Route
 | 
			
		||||
        path='/oauth/lark'
 | 
			
		||||
        element={
 | 
			
		||||
          <Suspense fallback={<Loading></Loading>}>
 | 
			
		||||
            <LarkOAuth />
 | 
			
		||||
          </Suspense>
 | 
			
		||||
        }
 | 
			
		||||
      />
 | 
			
		||||
      <Route
 | 
			
		||||
        path='/setting'
 | 
			
		||||
        element={
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										58
									
								
								web/default/src/components/LarkOAuth.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								web/default/src/components/LarkOAuth.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,58 @@
 | 
			
		||||
import React, { useContext, useEffect, useState } from 'react';
 | 
			
		||||
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { useNavigate, useSearchParams } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showSuccess } from '../helpers';
 | 
			
		||||
import { UserContext } from '../context/User';
 | 
			
		||||
 | 
			
		||||
const LarkOAuth = () => {
 | 
			
		||||
  const [searchParams, setSearchParams] = useSearchParams();
 | 
			
		||||
 | 
			
		||||
  const [userState, userDispatch] = useContext(UserContext);
 | 
			
		||||
  const [prompt, setPrompt] = useState('处理中...');
 | 
			
		||||
  const [processing, setProcessing] = useState(true);
 | 
			
		||||
 | 
			
		||||
  let navigate = useNavigate();
 | 
			
		||||
 | 
			
		||||
  const sendCode = async (code, state, count) => {
 | 
			
		||||
    const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      if (message === 'bind') {
 | 
			
		||||
        showSuccess('绑定成功!');
 | 
			
		||||
        navigate('/setting');
 | 
			
		||||
      } else {
 | 
			
		||||
        userDispatch({ type: 'login', payload: data });
 | 
			
		||||
        localStorage.setItem('user', JSON.stringify(data));
 | 
			
		||||
        showSuccess('登录成功!');
 | 
			
		||||
        navigate('/');
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
      if (count === 0) {
 | 
			
		||||
        setPrompt(`操作失败,重定向至登录界面中...`);
 | 
			
		||||
        navigate('/setting'); // in case this is failed to bind lark
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      count++;
 | 
			
		||||
      setPrompt(`出现错误,第 ${count} 次重试中...`);
 | 
			
		||||
      await new Promise((resolve) => setTimeout(resolve, count * 2000));
 | 
			
		||||
      await sendCode(code, state, count);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    let code = searchParams.get('code');
 | 
			
		||||
    let state = searchParams.get('state');
 | 
			
		||||
    sendCode(code, state, 0).then();
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <Segment style={{ minHeight: '300px' }}>
 | 
			
		||||
      <Dimmer active inverted>
 | 
			
		||||
        <Loader size='large'>{prompt}</Loader>
 | 
			
		||||
      </Dimmer>
 | 
			
		||||
    </Segment>
 | 
			
		||||
  );
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export default LarkOAuth;
 | 
			
		||||
@@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
 | 
			
		||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
 | 
			
		||||
import { UserContext } from '../context/User';
 | 
			
		||||
import { API, getLogo, showError, showSuccess, showWarning } from '../helpers';
 | 
			
		||||
import { onGitHubOAuthClicked } from './utils';
 | 
			
		||||
import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils';
 | 
			
		||||
import larkIcon from '../images/lark.svg';
 | 
			
		||||
 | 
			
		||||
const LoginForm = () => {
 | 
			
		||||
  const [inputs, setInputs] = useState({
 | 
			
		||||
@@ -124,7 +125,7 @@ const LoginForm = () => {
 | 
			
		||||
            点击注册
 | 
			
		||||
          </Link>
 | 
			
		||||
        </Message>
 | 
			
		||||
        {status.github_oauth || status.wechat_login ? (
 | 
			
		||||
        {status.github_oauth || status.wechat_login || status.lark_client_id ? (
 | 
			
		||||
          <>
 | 
			
		||||
            <Divider horizontal>Or</Divider>
 | 
			
		||||
            {status.github_oauth ? (
 | 
			
		||||
@@ -137,6 +138,18 @@ const LoginForm = () => {
 | 
			
		||||
            ) : (
 | 
			
		||||
              <></>
 | 
			
		||||
            )}
 | 
			
		||||
            {status.lark_client_id ? (
 | 
			
		||||
              <Button
 | 
			
		||||
                // circular
 | 
			
		||||
                color=''
 | 
			
		||||
                onClick={() => onLarkOAuthClicked(status.lark_client_id)}
 | 
			
		||||
                style={{ padding: 0, width: 36, height: 36 }}
 | 
			
		||||
              >
 | 
			
		||||
                <img src={larkIcon} width={36} height={36} />
 | 
			
		||||
              </Button>
 | 
			
		||||
            ) : (
 | 
			
		||||
              <></>
 | 
			
		||||
            )}
 | 
			
		||||
            {status.wechat_login ? (
 | 
			
		||||
              <Button
 | 
			
		||||
                circular
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
 | 
			
		||||
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
 | 
			
		||||
import Turnstile from 'react-turnstile';
 | 
			
		||||
import { UserContext } from '../context/User';
 | 
			
		||||
import { onGitHubOAuthClicked } from './utils';
 | 
			
		||||
import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils';
 | 
			
		||||
 | 
			
		||||
const PersonalSetting = () => {
 | 
			
		||||
  const [userState, userDispatch] = useContext(UserContext);
 | 
			
		||||
@@ -247,6 +247,11 @@ const PersonalSetting = () => {
 | 
			
		||||
          <Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      {
 | 
			
		||||
        status.lark_client_id && (
 | 
			
		||||
          <Button onClick={()=>{onLarkOAuthClicked(status.lark_client_id)}}>绑定飞书账号</Button>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      <Button
 | 
			
		||||
        onClick={() => {
 | 
			
		||||
          setShowEmailBindModal(true);
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,8 @@ const SystemSetting = () => {
 | 
			
		||||
    GitHubOAuthEnabled: '',
 | 
			
		||||
    GitHubClientId: '',
 | 
			
		||||
    GitHubClientSecret: '',
 | 
			
		||||
    LarkClientId: '',
 | 
			
		||||
    LarkClientSecret: '',
 | 
			
		||||
    Notice: '',
 | 
			
		||||
    SMTPServer: '',
 | 
			
		||||
    SMTPPort: '',
 | 
			
		||||
@@ -109,6 +111,8 @@ const SystemSetting = () => {
 | 
			
		||||
      name === 'ServerAddress' ||
 | 
			
		||||
      name === 'GitHubClientId' ||
 | 
			
		||||
      name === 'GitHubClientSecret' ||
 | 
			
		||||
      name === 'LarkClientId' ||
 | 
			
		||||
      name === 'LarkClientSecret' ||
 | 
			
		||||
      name === 'WeChatServerAddress' ||
 | 
			
		||||
      name === 'WeChatServerToken' ||
 | 
			
		||||
      name === 'WeChatAccountQRCodeImageURL' ||
 | 
			
		||||
@@ -212,6 +216,18 @@ const SystemSetting = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
   const submitLarkOAuth = async () => {
 | 
			
		||||
    if (originInputs['LarkClientId'] !== inputs.LarkClientId) {
 | 
			
		||||
      await updateOption('LarkClientId', inputs.LarkClientId);
 | 
			
		||||
    }
 | 
			
		||||
    if (
 | 
			
		||||
      originInputs['LarkClientSecret'] !== inputs.LarkClientSecret &&
 | 
			
		||||
      inputs.LarkClientSecret !== ''
 | 
			
		||||
    ) {
 | 
			
		||||
      await updateOption('LarkClientSecret', inputs.LarkClientSecret);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const submitTurnstile = async () => {
 | 
			
		||||
    if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) {
 | 
			
		||||
      await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey);
 | 
			
		||||
@@ -469,6 +485,44 @@ const SystemSetting = () => {
 | 
			
		||||
            保存 GitHub OAuth 设置
 | 
			
		||||
          </Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            配置飞书授权登录
 | 
			
		||||
            <Header.Subheader>
 | 
			
		||||
              用以支持通过飞书进行登录注册,
 | 
			
		||||
              <a href='https://open.feishu.cn/app' target='_blank'>
 | 
			
		||||
                点击此处
 | 
			
		||||
              </a>
 | 
			
		||||
              管理你的飞书应用
 | 
			
		||||
            </Header.Subheader>
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Message>
 | 
			
		||||
            主页链接填 <code>{inputs.ServerAddress}</code>
 | 
			
		||||
            ,重定向 URL 填{' '}
 | 
			
		||||
            <code>{`${inputs.ServerAddress}/oauth/lark`}</code>
 | 
			
		||||
          </Message>
 | 
			
		||||
          <Form.Group widths={3}>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='App ID'
 | 
			
		||||
              name='LarkClientId'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              value={inputs.LarkClientId}
 | 
			
		||||
              placeholder='输入 App ID'
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='App Secret'
 | 
			
		||||
              name='LarkClientSecret'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              type='password'
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              value={inputs.LarkClientSecret}
 | 
			
		||||
              placeholder='敏感信息不会发送到前端显示'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Button onClick={submitLarkOAuth}>
 | 
			
		||||
            保存飞书 OAuth 设置
 | 
			
		||||
          </Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            配置 WeChat Server
 | 
			
		||||
            <Header.Subheader>
 | 
			
		||||
 
 | 
			
		||||
@@ -17,4 +17,13 @@ export async function onGitHubOAuthClicked(github_client_id) {
 | 
			
		||||
  window.open(
 | 
			
		||||
    `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function onLarkOAuthClicked(lark_client_id) {
 | 
			
		||||
  const state = await getOAuthState();
 | 
			
		||||
  if (!state) return;
 | 
			
		||||
  let redirect_uri = `${window.location.origin}/oauth/lark`;
 | 
			
		||||
  window.open(
 | 
			
		||||
    `https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										1
									
								
								web/default/src/images/lark.svg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								web/default/src/images/lark.svg
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| 
		 After Width: | Height: | Size: 5.4 KiB  | 
@@ -15,7 +15,8 @@ const EditToken = () => {
 | 
			
		||||
    remain_quota: isEdit ? 0 : 500000,
 | 
			
		||||
    expired_time: -1,
 | 
			
		||||
    unlimited_quota: false,
 | 
			
		||||
    models: []
 | 
			
		||||
    models: [],
 | 
			
		||||
    subnet: "",
 | 
			
		||||
  };
 | 
			
		||||
  const [inputs, setInputs] = useState(originInputs);
 | 
			
		||||
  const { name, remain_quota, expired_time, unlimited_quota } = inputs;
 | 
			
		||||
@@ -153,6 +154,16 @@ const EditToken = () => {
 | 
			
		||||
              options={modelOptions}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='IP 限制'
 | 
			
		||||
              name='subnet'
 | 
			
		||||
              placeholder={'请输入允许访问的网段,例如:192.168.0.0/24'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={inputs.subnet}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='过期时间'
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ const TopUp = () => {
 | 
			
		||||
  const [topUpLink, setTopUpLink] = useState('');
 | 
			
		||||
  const [userQuota, setUserQuota] = useState(0);
 | 
			
		||||
  const [isSubmitting, setIsSubmitting] = useState(false);
 | 
			
		||||
  const [user, setUser] = useState({});
 | 
			
		||||
 | 
			
		||||
  const topUp = async () => {
 | 
			
		||||
    if (redemptionCode === '') {
 | 
			
		||||
@@ -41,7 +42,14 @@ const TopUp = () => {
 | 
			
		||||
      showError('超级管理员未设置充值链接!');
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    window.open(topUpLink, '_blank');
 | 
			
		||||
    let url = new URL(topUpLink);
 | 
			
		||||
    let username = user.username;
 | 
			
		||||
    let user_id = user.id;
 | 
			
		||||
    // add  username and user_id to the topup link
 | 
			
		||||
    url.searchParams.append('username', username);
 | 
			
		||||
    url.searchParams.append('user_id', user_id);
 | 
			
		||||
    url.searchParams.append('transaction_id', crypto.randomUUID());
 | 
			
		||||
    window.open(url.toString(), '_blank');
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const getUserQuota = async ()=>{
 | 
			
		||||
@@ -49,6 +57,7 @@ const TopUp = () => {
 | 
			
		||||
    const {success, message, data} = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      setUserQuota(data.quota);
 | 
			
		||||
      setUser(data);
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
@@ -80,7 +89,7 @@ const TopUp = () => {
 | 
			
		||||
              }}
 | 
			
		||||
            />
 | 
			
		||||
            <Button color='green' onClick={openTopUpLink}>
 | 
			
		||||
              获取兑换码
 | 
			
		||||
              充值
 | 
			
		||||
            </Button>
 | 
			
		||||
            <Button color='yellow' onClick={topUp} disabled={isSubmitting}>
 | 
			
		||||
                {isSubmitting ? '兑换中...' : '兑换'}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user