Compare commits

..

13 Commits

Author SHA1 Message Date
JustSong
d93cb8f645 feat: able to configure ratio for different models (close #26) 2023-04-28 19:16:37 +08:00
JustSong
b08cd7e104 refactor: use tiktoken-go to calculate token number 2023-04-28 18:36:17 +08:00
JustSong
aea6c859e7 fix: relay bug fix 2023-04-28 18:16:59 +08:00
JustSong
480e789cd8 feat: support configuring ratio when estimating token number in stream mode 2023-04-28 17:25:05 +08:00
JustSong
23ec541ba6 refactor: improve relay's implementation 2023-04-28 17:11:57 +08:00
JustSong
053bb85a1c feat: now use token as the unit of quota (close #33) 2023-04-28 16:58:55 +08:00
JustSong
601fa5cea8 refactor: use quota instead of times 2023-04-28 14:57:20 +08:00
JustSong
7a5057f02d fix: check user's role when manage user (#30) 2023-04-28 09:47:03 +08:00
JustSong
c76027a210 style: add bottom margin for unlimited times button 2023-04-27 17:18:07 +08:00
JustSong
f97c2b4c22 feat: able to set top up link now 2023-04-27 16:32:21 +08:00
JustSong
54b1e4adef fix: check user status when validating token (#23) 2023-04-27 15:05:33 +08:00
JustSong
9272884381 fix: root user cannot demote itself now (close #30) 2023-04-27 14:45:12 +08:00
JustSong
195e94a75d fix: fix MySQL syntax error (#54) 2023-04-27 11:10:10 +08:00
17 changed files with 411 additions and 91 deletions

View File

@@ -52,14 +52,18 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
+ [x] 自定义渠道 + [x] 自定义渠道
2. 支持通过负载均衡的方式访问多个渠道。 2. 支持通过负载均衡的方式访问多个渠道。
3. 支持单个访问渠道设置多个 API Key利用起来你的多个 API Key。 3. 支持单个访问渠道设置多个 API Key利用起来你的多个 API Key。
4. 支持设置令牌的过期时间和使用次数 4. 支持 HTTP SSE可以通过流式传输实现打字机效果
5. 支持 HTTP SSE 5. 支持设置令牌的过期时间和使用次数
6. 多种用户登录注册方式: 6. 支持批量生成和导出兑换码,可使用兑换码为令牌进行充值。
+ 邮箱登录注册以及通过邮箱进行密码重置 7. 支持为新用户设置初始配额
+ [GitHub 开放授权](https://github.com/settings/applications/new) 8. 支持发布公告,在线修改关于页面,设置充值链接,自定义页脚
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server) 9. 支持通过系统访问令牌访问管理 API
7. 支持用户管理。 10. 多种用户登录注册方式:
8. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式 + 邮箱登录注册以及通过邮箱进行密码重置
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
11. 支持用户管理。
12. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
## 部署 ## 部署
### 基于 Docker 进行部署 ### 基于 Docker 进行部署

View File

@@ -11,6 +11,7 @@ var Version = "v0.0.0" // this hard coding will be replaced automatic
var SystemName = "One API" var SystemName = "One API"
var ServerAddress = "http://localhost:3000" var ServerAddress = "http://localhost:3000"
var Footer = "" var Footer = ""
var TopUpLink = ""
var UsingSQLite = false var UsingSQLite = false
@@ -48,6 +49,11 @@ var TurnstileSecretKey = ""
var QuotaForNewUser = 100 var QuotaForNewUser = 100
// https://platform.openai.com/docs/models/model-endpoint-compatibility
var RatioGPT3dot5 float64 = 2
var RatioGPT4 float64 = 30
var RatioGPT4_32k float64 = 60
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
RoleCommonUser = 1 RoleCommonUser = 1

View File

@@ -26,6 +26,7 @@ func GetStatus(c *gin.Context) {
"server_address": common.ServerAddress, "server_address": common.ServerAddress,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
}, },
}) })
return return

View File

@@ -2,8 +2,11 @@ package controller
import ( import (
"bufio" "bufio"
"bytes"
"encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -11,16 +14,46 @@ import (
"strings" "strings"
) )
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
//Stream bool `json:"stream"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TextResponse struct {
Usage `json:"usage"`
}
type StreamResponse struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
func countToken(text string) int {
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
channelType := c.GetInt("channel") err := relayHelper(c)
tokenId := c.GetInt("token_id")
isUnlimitedTimes := c.GetBool("unlimited_times")
baseURL := common.ChannelBaseURLs[channelType]
if channelType == common.ChannelTypeCustom {
baseURL = c.GetString("base_url")
}
requestURL := c.Request.URL.String()
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"error": gin.H{ "error": gin.H{
@@ -28,42 +61,90 @@ func Relay(c *gin.Context) {
"type": "one_api_error", "type": "one_api_error",
}, },
}) })
return
} }
//req.Header = c.Request.Header.Clone() }
// Fix HTTP Decompression failed
// https://github.com/stoplightio/prism/issues/1064#issuecomment-824682360 func relayHelper(c *gin.Context) error {
//req.Header.Del("Accept-Encoding") channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id")
consumeQuota := c.GetBool("consume_quota")
baseURL := common.ChannelBaseURLs[channelType]
if channelType == common.ChannelTypeCustom {
baseURL = c.GetString("base_url")
}
var textRequest TextRequest
if consumeQuota {
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return err
}
err = c.Request.Body.Close()
if err != nil {
return err
}
err = json.Unmarshal(requestBody, &textRequest)
if err != nil {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
}
requestURL := c.Request.URL.String()
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
if err != nil {
return err
}
err = c.Request.Body.Close()
if err != nil {
return err
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("Connection", c.Request.Header.Get("Connection")) req.Header.Set("Connection", c.Request.Header.Get("Connection"))
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ return err
"error": gin.H{ }
"message": err.Error(), err = req.Body.Close()
"type": "one_api_error", if err != nil {
}, return err
})
return
} }
var textResponse TextResponse
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
var streamResponseText string
defer func() { defer func() {
err := req.Body.Close() if consumeQuota {
if err != nil { quota := 0
common.SysError("Error closing request body: " + err.Error()) if isStream {
} var text string
if !isUnlimitedTimes && requestURL == "/v1/chat/completions" { for _, message := range textRequest.Messages {
err := model.DecreaseTokenRemainTimesById(tokenId) text += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
}
text += fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
quota = countToken(text) + 3
} else {
quota = textResponse.Usage.TotalTokens
}
ratio := common.RatioGPT3dot5
if strings.HasPrefix(textRequest.Model, "gpt-4-32k") {
ratio = common.RatioGPT4_32k
} else if strings.HasPrefix(textRequest.Model, "gpt-4") {
ratio = common.RatioGPT4
} else {
ratio = common.RatioGPT3dot5
}
quota = int(float64(quota) * ratio)
err := model.ConsumeTokenQuota(tokenId, quota)
if err != nil { if err != nil {
common.SysError("Error decreasing token remain times: " + err.Error()) common.SysError("Error consuming token remain quota: " + err.Error())
} }
} }
}() }()
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
if isStream { if isStream {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -87,6 +168,18 @@ func Relay(c *gin.Context) {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
dataChan <- data dataChan <- data
data = data[6:]
if data != "[DONE]" {
var streamResponse StreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("Error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
}
} }
stopChan <- true stopChan <- true
}() }()
@@ -103,20 +196,48 @@ func Relay(c *gin.Context) {
return false return false
} }
}) })
return err = resp.Body.Close()
if err != nil {
return err
}
return nil
} else { } else {
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
err = resp.Body.Close()
if err != nil {
return err
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return err
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ return err
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
} }
err = resp.Body.Close()
if err != nil {
return err
}
return nil
} }
} }
func RelayNotImplemented(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": "Not Implemented",
"type": "one_api_error",
},
})
}

View File

@@ -102,8 +102,8 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
} }
if isAdmin { if isAdmin {
cleanToken.RemainTimes = token.RemainTimes cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedTimes = token.UnlimitedTimes cleanToken.UnlimitedQuota = token.UnlimitedQuota
} else { } else {
userId := c.GetInt("id") userId := c.GetInt("id")
quota, err := model.GetUserQuota(userId) quota, err := model.GetUserQuota(userId)
@@ -115,7 +115,7 @@ func AddToken(c *gin.Context) {
return return
} }
if quota > 0 { if quota > 0 {
cleanToken.RemainTimes = quota cleanToken.RemainQuota = quota
} }
} }
err = cleanToken.Insert() err = cleanToken.Insert()
@@ -128,7 +128,7 @@ func AddToken(c *gin.Context) {
} }
if !isAdmin { if !isAdmin {
// update user quota // update user quota
err = model.DecreaseUserQuota(c.GetInt("id"), cleanToken.RemainTimes) err = model.DecreaseUserQuota(c.GetInt("id"), cleanToken.RemainQuota)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
@@ -184,7 +184,7 @@ func UpdateToken(c *gin.Context) {
}) })
return return
} }
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainTimes <= 0 && !cleanToken.UnlimitedTimes { if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌可用次数已用尽,无法启用,请先修改令牌剩余次数,或者设置为无限次数", "message": "令牌可用次数已用尽,无法启用,请先修改令牌剩余次数,或者设置为无限次数",
@@ -199,8 +199,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.Name = token.Name cleanToken.Name = token.Name
cleanToken.ExpiredTime = token.ExpiredTime cleanToken.ExpiredTime = token.ExpiredTime
if isAdmin { if isAdmin {
cleanToken.RemainTimes = token.RemainTimes cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedTimes = token.UnlimitedTimes cleanToken.UnlimitedQuota = token.UnlimitedQuota
} }
} }
err = cleanToken.Update() err = cleanToken.Update()

View File

@@ -539,9 +539,23 @@ func ManageUser(c *gin.Context) {
switch req.Action { switch req.Action {
case "disable": case "disable":
user.Status = common.UserStatusDisabled user.Status = common.UserStatusDisabled
if user.Role == common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法禁用超级管理员用户",
})
return
}
case "enable": case "enable":
user.Status = common.UserStatusEnabled user.Status = common.UserStatusEnabled
case "delete": case "delete":
if user.Role == common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法删除超级管理员用户",
})
return
}
if err := user.Delete(); err != nil { if err := user.Delete(); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -557,8 +571,29 @@ func ManageUser(c *gin.Context) {
}) })
return return
} }
if user.Role >= common.RoleAdminUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是管理员",
})
return
}
user.Role = common.RoleAdminUser user.Role = common.RoleAdminUser
case "demote": case "demote":
if user.Role == common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法降级超级管理员用户",
})
return
}
if user.Role == common.RoleCommonUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是普通用户",
})
return
}
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
} }

2
go.mod
View File

@@ -25,6 +25,7 @@ require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
@@ -44,6 +45,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.7 // indirect github.com/pelletier/go-toml/v2 v2.0.7 // indirect
github.com/pkoukk/tiktoken-go v0.1.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect

4
go.sum
View File

@@ -14,6 +14,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs= github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
@@ -119,6 +121,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO
github.com/pelletier/go-toml/v2 v2.0.7 h1:muncTPStnKRos5dpVKULv2FVd4bMOhNePj9CjgDb8Us= github.com/pelletier/go-toml/v2 v2.0.7 h1:muncTPStnKRos5dpVKULv2FVd4bMOhNePj9CjgDb8Us=
github.com/pelletier/go-toml/v2 v2.0.7/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= github.com/pelletier/go-toml/v2 v2.0.7/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=

View File

@@ -98,9 +98,29 @@ func TokenAuth() func(c *gin.Context) {
c.Abort() c.Abort()
return return
} }
if !model.IsUserEnabled(token.UserId) {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": "用户已被封禁",
"type": "one_api_error",
},
})
c.Abort()
return
}
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("unlimited_times", token.UnlimitedTimes) requestURL := c.Request.URL.String()
consumeQuota := false
switch requestURL {
case "/v1/chat/completions":
consumeQuota = !token.UnlimitedQuota
case "/v1/completions":
consumeQuota = !token.UnlimitedQuota
case "/v1/edits":
consumeQuota = !token.UnlimitedQuota
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])

View File

@@ -47,6 +47,10 @@ func InitOptionMap() {
common.OptionMap["TurnstileSiteKey"] = "" common.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSecretKey"] = "" common.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
common.OptionMap["RatioGPT3dot5"] = strconv.FormatFloat(common.RatioGPT3dot5, 'f', -1, 64)
common.OptionMap["RatioGPT4"] = strconv.FormatFloat(common.RatioGPT4, 'f', -1, 64)
common.OptionMap["RatioGPT4_32k"] = strconv.FormatFloat(common.RatioGPT4_32k, 'f', -1, 64)
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
options, _ := AllOption() options, _ := AllOption()
for _, option := range options { for _, option := range options {
@@ -134,5 +138,13 @@ func updateOptionMap(key string, value string) {
common.TurnstileSecretKey = value common.TurnstileSecretKey = value
case "QuotaForNewUser": case "QuotaForNewUser":
common.QuotaForNewUser, _ = strconv.Atoi(value) common.QuotaForNewUser, _ = strconv.Atoi(value)
case "RatioGPT3dot5":
common.RatioGPT3dot5, _ = strconv.ParseFloat(value, 64)
case "RatioGPT4":
common.RatioGPT4, _ = strconv.ParseFloat(value, 64)
case "RatioGPT4_32k":
common.RatioGPT4_32k, _ = strconv.ParseFloat(value, 64)
case "TopUpLink":
common.TopUpLink = value
} }
} }

View File

@@ -48,14 +48,14 @@ func Redeem(key string, tokenId int) (quota int, err error) {
return 0, errors.New("未提供 token id") return 0, errors.New("未提供 token id")
} }
redemption := &Redemption{} redemption := &Redemption{}
err = DB.Where("key = ?", key).First(redemption).Error err = DB.Where("`key` = ?", key).First(redemption).Error
if err != nil { if err != nil {
return 0, errors.New("无效的兑换码") return 0, errors.New("无效的兑换码")
} }
if redemption.Status != common.RedemptionCodeStatusEnabled { if redemption.Status != common.RedemptionCodeStatusEnabled {
return 0, errors.New("该兑换码已被使用") return 0, errors.New("该兑换码已被使用")
} }
err = TopUpToken(tokenId, redemption.Quota) err = TopUpTokenQuota(tokenId, redemption.Quota)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@@ -17,8 +17,8 @@ type Token struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"` AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainTimes int `json:"remain_times" gorm:"default:0"` RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedTimes bool `json:"unlimited_times" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
} }
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@@ -39,7 +39,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
key = strings.Replace(key, "Bearer ", "", 1) key = strings.Replace(key, "Bearer ", "", 1)
token = &Token{} token = &Token{}
err = DB.Where("key = ?", key).First(token).Error err = DB.Where("`key` = ?", key).First(token).Error
if err == nil { if err == nil {
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该 token 状态不可用") return nil, errors.New("该 token 状态不可用")
@@ -52,13 +52,13 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
return nil, errors.New("该 token 已过期") return nil, errors.New("该 token 已过期")
} }
if !token.UnlimitedTimes && token.RemainTimes <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
token.Status = common.TokenStatusExhausted token.Status = common.TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysError("更新 token 状态失败:" + err.Error()) common.SysError("更新 token 状态失败:" + err.Error())
} }
return nil, errors.New("该 token 可用次数已用尽") return nil, errors.New("该 token 额度已用尽")
} }
go func() { go func() {
token.AccessedTime = common.GetTimestamp() token.AccessedTime = common.GetTimestamp()
@@ -91,7 +91,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
var err error var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_times", "unlimited_times").Updates(token).Error err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
return err return err
} }
@@ -119,12 +119,12 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete() return token.Delete()
} }
func DecreaseTokenRemainTimesById(id int) (err error) { func ConsumeTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_times", gorm.Expr("remain_times - ?", 1)).Error err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
return err return err
} }
func TopUpToken(id int, times int) (err error) { func TopUpTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_times", gorm.Expr("remain_times + ?", times)).Error err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
return err return err
} }

View File

@@ -195,6 +195,19 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
} }
func IsUserEnabled(userId int) bool {
if userId == 0 {
return false
}
var user User
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
if err != nil {
common.SysError("No such user " + err.Error())
return false
}
return user.Status == common.UserStatusEnabled
}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
if token == "" { if token == "" {
return nil return nil

View File

@@ -7,12 +7,35 @@ import (
) )
func SetRelayRouter(router *gin.Engine) { func SetRelayRouter(router *gin.Engine) {
// https://platform.openai.com/docs/api-reference/introduction
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{ {
relayV1Router.Any("/*path", controller.Relay) relayV1Router.GET("/models", controller.Relay)
relayV1Router.GET("/models/:model", controller.Relay)
relayV1Router.POST("/completions", controller.RelayNotImplemented)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.RelayNotImplemented)
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.RelayNotImplemented)
} }
relayDashboardRouter := router.Group("/dashboard") relayDashboardRouter := router.Group("/dashboard") // TODO: return system's own token info
relayDashboardRouter.Use(middleware.TokenAuth(), middleware.Distribute()) relayDashboardRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{ {
relayDashboardRouter.Any("/*path", controller.Relay) relayDashboardRouter.Any("/*path", controller.Relay)

View File

@@ -25,6 +25,10 @@ const SystemSetting = () => {
TurnstileSecretKey: '', TurnstileSecretKey: '',
RegisterEnabled: '', RegisterEnabled: '',
QuotaForNewUser: 0, QuotaForNewUser: 0,
RatioGPT3dot5: 2,
RatioGPT4: 30,
RatioGPT4_32k: 60,
TopUpLink: ''
}); });
let originInputs = {}; let originInputs = {};
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);
@@ -65,7 +69,7 @@ const SystemSetting = () => {
} }
const res = await API.put('/api/option', { const res = await API.put('/api/option', {
key, key,
value, value
}); });
const { success, message } = res.data; const { success, message } = res.data;
if (success) { if (success) {
@@ -88,7 +92,9 @@ const SystemSetting = () => {
name === 'WeChatAccountQRCodeImageURL' || name === 'WeChatAccountQRCodeImageURL' ||
name === 'TurnstileSiteKey' || name === 'TurnstileSiteKey' ||
name === 'TurnstileSecretKey' || name === 'TurnstileSecretKey' ||
name === 'QuotaForNewUser' name === 'QuotaForNewUser' ||
name.startsWith('Ratio') ||
name === 'TopUpLink'
) { ) {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
} else { } else {
@@ -101,6 +107,24 @@ const SystemSetting = () => {
await updateOption('ServerAddress', ServerAddress); await updateOption('ServerAddress', ServerAddress);
}; };
const submitOperationConfig = async () => {
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
}
if (originInputs['RatioGPT3dot5'] !== inputs.RatioGPT3dot5) {
await updateOption('RatioGPT3dot5', inputs.RatioGPT3dot5);
}
if (originInputs['RatioGPT4'] !== inputs.RatioGPT4) {
await updateOption('RatioGPT4', inputs.RatioGPT4);
}
if (originInputs['RatioGPT4_32k'] !== inputs.RatioGPT4_32k) {
await updateOption('RatioGPT4_32k', inputs.RatioGPT4_32k);
}
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
}
const submitSMTP = async () => { const submitSMTP = async () => {
if (originInputs['SMTPServer'] !== inputs.SMTPServer) { if (originInputs['SMTPServer'] !== inputs.SMTPServer) {
await updateOption('SMTPServer', inputs.SMTPServer); await updateOption('SMTPServer', inputs.SMTPServer);
@@ -244,10 +268,52 @@ const SystemSetting = () => {
min='0' min='0'
placeholder='例如100' placeholder='例如100'
/> />
<Form.Input
label='充值链接'
name='TopUpLink'
onChange={handleInputChange}
autoComplete='off'
value={inputs.TopUpLink}
type='link'
placeholder='例如发卡网站的购买链接'
/>
</Form.Group> </Form.Group>
<Form.Button onClick={()=>{ <Form.Group widths={3}>
updateOption('QuotaForNewUser', inputs.QuotaForNewUser).then(); <Form.Input
}}>保存运营设置</Form.Button> label='GPT-3.5 系列模型倍率'
name='RatioGPT3dot5'
onChange={handleInputChange}
autoComplete='off'
value={inputs.RatioGPT3dot5}
type='number'
step='0.01'
min='0'
placeholder='例如2'
/>
<Form.Input
label='GPT-4 系列模型倍率'
name='RatioGPT4'
onChange={handleInputChange}
autoComplete='off'
value={inputs.RatioGPT4}
type='number'
step='0.01'
min='0'
placeholder='例如30'
/>
<Form.Input
label='GPT-4 32k 系列模型倍率'
name='RatioGPT4_32k'
onChange={handleInputChange}
autoComplete='off'
value={inputs.RatioGPT4_32k}
type='number'
step='0.01'
min='0'
placeholder='例如60'
/>
</Form.Group>
<Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
<Divider /> <Divider />
<Header as='h3'> <Header as='h3'>
配置 SMTP 配置 SMTP

View File

@@ -37,6 +37,7 @@ const TokensTable = () => {
const [showTopUpModal, setShowTopUpModal] = useState(false); const [showTopUpModal, setShowTopUpModal] = useState(false);
const [targetTokenIdx, setTargetTokenIdx] = useState(0); const [targetTokenIdx, setTargetTokenIdx] = useState(0);
const [redemptionCode, setRedemptionCode] = useState(''); const [redemptionCode, setRedemptionCode] = useState('');
const [topUpLink, setTopUpLink] = useState('');
const loadTokens = async (startIdx) => { const loadTokens = async (startIdx) => {
const res = await API.get(`/api/token/?p=${startIdx}`); const res = await API.get(`/api/token/?p=${startIdx}`);
@@ -71,6 +72,13 @@ const TokensTable = () => {
.catch((reason) => { .catch((reason) => {
showError(reason); showError(reason);
}); });
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
if (status.top_up_link) {
setTopUpLink(status.top_up_link);
}
}
}, []); }, []);
const manageToken = async (id, action, idx) => { const manageToken = async (id, action, idx) => {
@@ -156,7 +164,7 @@ const TokensTable = () => {
showSuccess('充值成功!'); showSuccess('充值成功!');
let newTokens = [...tokens]; let newTokens = [...tokens];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + targetTokenIdx; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + targetTokenIdx;
newTokens[realIdx].remain_times += data; newTokens[realIdx].remain_quota += data;
setTokens(newTokens); setTokens(newTokens);
setRedemptionCode(''); setRedemptionCode('');
setShowTopUpModal(false); setShowTopUpModal(false);
@@ -209,10 +217,10 @@ const TokensTable = () => {
<Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
onClick={() => { onClick={() => {
sortToken('remain_times'); sortToken('remain_quota');
}} }}
> >
剩余次数 额度
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
@@ -247,7 +255,7 @@ const TokensTable = () => {
<Table.Cell>{token.id}</Table.Cell> <Table.Cell>{token.id}</Table.Cell>
<Table.Cell>{token.name ? token.name : '无'}</Table.Cell> <Table.Cell>{token.name ? token.name : '无'}</Table.Cell>
<Table.Cell>{renderStatus(token.status)}</Table.Cell> <Table.Cell>{renderStatus(token.status)}</Table.Cell>
<Table.Cell>{token.unlimited_times ? '无限制' : token.remain_times}</Table.Cell> <Table.Cell>{token.unlimited_quota ? '无限制' : token.remain_quota}</Table.Cell>
<Table.Cell>{renderTimestamp(token.created_time)}</Table.Cell> <Table.Cell>{renderTimestamp(token.created_time)}</Table.Cell>
<Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell> <Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell>
<Table.Cell> <Table.Cell>
@@ -342,6 +350,11 @@ const TokensTable = () => {
<Modal.Content> <Modal.Content>
<Modal.Description> <Modal.Description>
{/*<Image src={status.wechat_qrcode} fluid />*/} {/*<Image src={status.wechat_qrcode} fluid />*/}
{
topUpLink && <p>
<a target='_blank' href={topUpLink}>点击此处获取兑换码</a>
</p>
}
<Form size='large'> <Form size='large'>
<Form.Input <Form.Input
fluid fluid

View File

@@ -10,13 +10,13 @@ const EditToken = () => {
const [loading, setLoading] = useState(isEdit); const [loading, setLoading] = useState(isEdit);
const originInputs = { const originInputs = {
name: '', name: '',
remain_times: 0, remain_quota: 0,
expired_time: -1, expired_time: -1,
unlimited_times: false unlimited_quota: false
}; };
const isAdminUser = isAdmin(); const isAdminUser = isAdmin();
const [inputs, setInputs] = useState(originInputs); const [inputs, setInputs] = useState(originInputs);
const { name, remain_times, expired_time, unlimited_times } = inputs; const { name, remain_quota, expired_time, unlimited_quota } = inputs;
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -37,8 +37,8 @@ const EditToken = () => {
} }
}; };
const setUnlimitedTimes = () => { const setUnlimitedQuota = () => {
setInputs({ ...inputs, unlimited_times: !unlimited_times }); setInputs({ ...inputs, unlimited_quota: !unlimited_quota });
}; };
const loadToken = async () => { const loadToken = async () => {
@@ -63,7 +63,7 @@ const EditToken = () => {
const submit = async () => { const submit = async () => {
if (!isEdit && inputs.name === '') return; if (!isEdit && inputs.name === '') return;
let localInputs = inputs; let localInputs = inputs;
localInputs.remain_times = parseInt(localInputs.remain_times); localInputs.remain_quota = parseInt(localInputs.remain_quota);
if (localInputs.expired_time !== -1) { if (localInputs.expired_time !== -1) {
let time = Date.parse(localInputs.expired_time); let time = Date.parse(localInputs.expired_time);
if (isNaN(time)) { if (isNaN(time)) {
@@ -111,19 +111,19 @@ const EditToken = () => {
isAdminUser && <> isAdminUser && <>
<Form.Field> <Form.Field>
<Form.Input <Form.Input
label='剩余次数' label='额度'
name='remain_times' name='remain_quota'
placeholder={'请输入剩余次数'} placeholder={'请输入额度'}
onChange={handleInputChange} onChange={handleInputChange}
value={remain_times} value={remain_quota}
autoComplete='off' autoComplete='off'
type='number' type='number'
disabled={unlimited_times} disabled={unlimited_quota}
/> />
</Form.Field> </Form.Field>
<Button type={'button'} onClick={() => { <Button type={'button'} style={{marginBottom: '14px'}} onClick={() => {
setUnlimitedTimes(); setUnlimitedQuota();
}}>{unlimited_times ? '取消无限' : '设置为无限'}</Button> }}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
</> </>
} }
<Form.Field> <Form.Field>