Compare commits

..

16 Commits

Author SHA1 Message Date
MaricoHan
aec343dc38 feat: support xunfei v3 (#637) 2023-10-29 22:03:01 +08:00
JustSong
89d458b9cf feat: able to set RELAY_TIMEOUT 2023-10-22 20:39:49 +08:00
JustSong
63fafba112 feat: support ERNIE-Bot-4 (close #608) 2023-10-22 18:48:35 +08:00
Bryan
a398f35968 fix: fix postgresql support (#606)
* fix postgresql support

fixes #517

* fix: fix pg support

* chore: delete useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 18:38:29 +08:00
yiGmMk
57aa637c77 fix: set Accept header if not given (#615)
* fix: fastgpt调用通义千问问答失败

* refactor: Dockerfile

* Revert "refactor: Dockerfile"

This reverts commit a538c4f28e.

* chore: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 17:56:20 +08:00
vc
3b483639a4 feat: add cloudflare ai gateway support for image & audio (#607)
* Update channel-test.go

* Update relay-audio.go

* Update relay-image.go

* chore: using a util function

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 17:50:52 +08:00
subnew
22980b4c44 docs: add description for TIKTOKEN_CACHE_DIR (#612)
* Update README.md

* Update README.md
2023-10-22 17:31:27 +08:00
Pluto
64cdb7eafb fix: docker compose healthcheck failed (#593) 2023-10-14 21:55:16 -05:00
JustSong
824444244b feat: able to delete all disabled channels 2023-10-14 17:25:48 +08:00
JustSong
fbe9985f57 chore: show prompt to let the user know 2023-10-14 16:32:01 +08:00
JustSong
a27a5bcc06 fix: fix array index not checked (close #588) 2023-10-14 16:11:15 +08:00
JustSong
e28d4b1741 feat: support cloudflare AI gateway now (close #565, #598) 2023-10-14 15:26:28 +08:00
JustSong
f073592d39 fix: fix request count not updated correctly when using batch update 2023-10-14 15:04:52 +08:00
阿鹏
fa41ca9805 fix: fix url not passing (#562)
解决令牌页面聊天按钮丢失url参数的问题
2023-10-14 12:45:00 +08:00
Mikey
e338de45b6 fix: 404 Component is missing (#592)
* fix: 404 Component is missing

* chore: update 404 page style

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-14 12:42:07 +08:00
dependabot[bot]
114587b46f chore(deps): bump golang.org/x/net from 0.10.0 to 0.17.0 (#591)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.10.0 to 0.17.0.
- [Commits](https://github.com/golang/net/compare/v0.10.0...v0.17.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-14 12:34:46 +08:00
30 changed files with 197 additions and 92 deletions

View File

@@ -95,12 +95,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
15. 支持模型映射,重定向用户的请求模型。
16. 支持失败自动重试。
17. 支持绘图接口。
18. 支持丰富的**自定义**设置,
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
19. 支持通过系统访问令牌访问管理 API。
20. 支持 Cloudflare Turnstile 用户校验。
21. 支持用户管理,支持**多种用户登录注册方式**
20. 支持通过系统访问令牌访问管理 API。
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
@@ -351,6 +352,10 @@ graph LR
13. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var UsingSQLite = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var SQLitePath = "one-api.db"
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
@@ -98,6 +95,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
const (
RequestIdKey = "X-Oneapi-Request-Id"
)

6
common/database.go Normal file
View File

@@ -0,0 +1,6 @@
package common
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"

View File

@@ -46,6 +46,7 @@ var ModelRatio = map[string]float64{
"claude-2": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens

View File

@@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -5,13 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {

View File

@@ -127,8 +127,8 @@ func DeleteChannel(c *gin.Context) {
return
}
func DeleteManuallyDisabledChannel(c *gin.Context) {
rows, err := model.DeleteChannelByStatus(common.ChannelStatusManuallyDisabled)
func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -306,6 +306,15 @@ func init() {
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",

View File

@@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

View File

@@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)

View File

@@ -6,13 +6,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
@@ -31,7 +32,14 @@ var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
httpClient = &http.Client{}
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
@@ -118,7 +126,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
@@ -151,6 +159,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
@@ -361,6 +371,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {

View File

@@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
return fullRequestURL
}

View File

@@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
for !stop {
select {
case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
@@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
}
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl

View File

@@ -23,7 +23,7 @@ services:
depends_on:
- redis
healthcheck:
test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3

10
go.mod
View File

@@ -15,8 +15,9 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
golang.org/x/crypto v0.9.0
golang.org/x/crypto v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
)
@@ -52,10 +53,9 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.5.2 // indirect
)

17
go.sum
View File

@@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

View File

@@ -15,10 +15,17 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error

View File

@@ -21,14 +21,18 @@ var (
)
func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil {
return nil, err
}

View File

@@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err
}
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Create(&channels).Error
@@ -181,3 +174,8 @@ func DeleteChannelByStatus(status int64) (int64, error) {
result := DB.Where("status = ?", status).Delete(&Channel{})
return result.RowsAffected, result.Error
}
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage

View File

@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
}
redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}

View File

@@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err
}
@@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) {
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
}
updateUserUsedQuotaAndRequestCount(id, quota, 1)
@@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
}
}
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
common.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysError("failed to update user request count: " + err.Error())
}
}
func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username

View File

@@ -6,13 +6,13 @@ import (
"time"
)
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
const (
BatchUpdateTypeUserQuota = iota
BatchUpdateTypeTokenQuota
BatchUpdateTypeUsedQuotaAndRequestCount
BatchUpdateTypeUsedQuota
BatchUpdateTypeChannelUsedQuota
BatchUpdateTypeRequestCount
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
)
var batchUpdateStores []map[int]int
@@ -51,7 +51,7 @@ func batchUpdate() {
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store {
switch i {
case BatchUpdateTypeUserQuota:
@@ -64,8 +64,10 @@ func batchUpdate() {
if err != nil {
common.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuotaAndRequestCount:
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}

View File

@@ -74,7 +74,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/manually_disabled", controller.DeleteManuallyDisabledChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel)
}
tokenRoute := apiRouter.Group("/token")

View File

@@ -283,7 +283,9 @@ function App() {
</Suspense>
}
/>
<Route path='*' element={NotFound} />
<Route path='*' element={
<NotFound />
} />
</Routes>
);
}

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Input, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
@@ -55,6 +55,7 @@ const ChannelsTable = () => {
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [updatingBalance, setUpdatingBalance] = useState(false);
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
const loadChannels = async (startIdx) => {
const res = await API.get(`/api/channel/?p=${startIdx}`);
@@ -226,7 +227,6 @@ const ChannelsTable = () => {
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
showNotice('当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。');
}
};
@@ -240,11 +240,11 @@ const ChannelsTable = () => {
}
};
const deleteAllManuallyDisabledChannels = async () => {
const res = await API.delete(`/api/channel/manually_disabled`);
const deleteAllDisabledChannels = async () => {
const res = await API.delete(`/api/channel/disabled`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`已删除所有手动禁用渠道,共计 ${data}`);
showSuccess(`已删除所有禁用渠道,共计 ${data}`);
await refresh();
} else {
showError(message);
@@ -317,7 +317,19 @@ const ChannelsTable = () => {
onChange={handleKeywordChange}
/>
</Form>
{
showPrompt && (
<Message onDismiss={() => {
setShowPrompt(false);
setPromptShown("channel-test");
}}>
当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
模型进行非流式请求实现的因此测试报错并不一定代表通道不可用该功能后续会修复
另外OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为 0对于支持的渠道类型请点击余额进行刷新
</Message>
)
}
<Table basic compact size='small'>
<Table.Header>
<Table.Row>
@@ -519,14 +531,14 @@ const ChannelsTable = () => {
<Popup
trigger={
<Button size='small' loading={loading}>
删除所有手动禁用渠道
删除禁用渠道
</Button>
}
on='click'
flowing
hoverable
>
<Button size='small' loading={loading} negative onClick={deleteAllManuallyDisabledChannels}>
<Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
确认删除
</Button>
</Popup>

View File

@@ -138,7 +138,7 @@ const TokensTable = () => {
let defaultUrl;
if (chatLink) {
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`;
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} else {
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}

View File

@@ -186,4 +186,14 @@ export const verifyJSON = (str) => {
return false;
}
return true;
};
};
export function shouldShowPrompt(id) {
let prompt = localStorage.getItem(`prompt-${id}`);
return !prompt;
}
export function setPromptShown(id) {
localStorage.setItem(`prompt-${id}`, 'true');
}

View File

@@ -66,7 +66,7 @@ const EditChannel = () => {
localModels = ['PaLM-2'];
break;
case 15:
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];

View File

@@ -1,19 +1,12 @@
import React from 'react';
import { Segment, Header } from 'semantic-ui-react';
import { Message } from 'semantic-ui-react';
const NotFound = () => (
<>
<Header
block
as="h4"
content="404"
attached="top"
icon="info"
className="small-icon"
/>
<Segment attached="bottom">
未找到所请求的页面
</Segment>
<Message negative>
<Message.Header>页面不存在</Message.Header>
<p>请检查你的浏览器地址是否正确</p>
</Message>
</>
);