Compare commits

..

25 Commits

Author SHA1 Message Date
wzxjohn
bc7c9105f4 chore: update quota calc logic (close #599) (#627)
* fix: change quota calc code (close #599)

Use float64 during calc and do math.Ceil after calc. This will result in the quota being used slightly more than the official standard, but it will be guaranteed that it will not be less.

* chore: remove blank line

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-05 19:15:06 +08:00
wood chen
3fe76c8af7 fix: fix Cloudflare AI Gateway channel test support (#639)
* 当使用Cloudflare AI Gateway时,支持openai渠道测试

* refactor: change logic

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-05 19:08:25 +08:00
papersnake
c70c614018 feat: support chatglm_turbo (#648)
* feat: support chatglm_turbo

* fix: remove characterglm
2023-11-05 17:59:38 +08:00
Baksi
0d87de697c fix: fix typo (#651) 2023-11-02 22:24:22 +08:00
MaricoHan
aec343dc38 feat: support xunfei v3 (#637) 2023-10-29 22:03:01 +08:00
JustSong
89d458b9cf feat: able to set RELAY_TIMEOUT 2023-10-22 20:39:49 +08:00
JustSong
63fafba112 feat: support ERNIE-Bot-4 (close #608) 2023-10-22 18:48:35 +08:00
Bryan
a398f35968 fix: fix postgresql support (#606)
* fix postgresql support

fixes #517

* fix: fix pg support

* chore: delete useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 18:38:29 +08:00
yiGmMk
57aa637c77 fix: set Accept header if not given (#615)
* fix: fastgpt调用通义千问问答失败

* refactor: Dockerfile

* Revert "refactor: Dockerfile"

This reverts commit a538c4f28e.

* chore: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 17:56:20 +08:00
vc
3b483639a4 feat: add cloudflare ai gateway support for image & audio (#607)
* Update channel-test.go

* Update relay-audio.go

* Update relay-image.go

* chore: using a util function

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-22 17:50:52 +08:00
subnew
22980b4c44 docs: add description for TIKTOKEN_CACHE_DIR (#612)
* Update README.md

* Update README.md
2023-10-22 17:31:27 +08:00
Pluto
64cdb7eafb fix: docker compose healthcheck failed (#593) 2023-10-14 21:55:16 -05:00
JustSong
824444244b feat: able to delete all disabled channels 2023-10-14 17:25:48 +08:00
JustSong
fbe9985f57 chore: show prompt to let the user know 2023-10-14 16:32:01 +08:00
JustSong
a27a5bcc06 fix: fix array index not checked (close #588) 2023-10-14 16:11:15 +08:00
JustSong
e28d4b1741 feat: support cloudflare AI gateway now (close #565, #598) 2023-10-14 15:26:28 +08:00
JustSong
f073592d39 fix: fix request count not updated correctly when using batch update 2023-10-14 15:04:52 +08:00
阿鹏
fa41ca9805 fix: fix url not passing (#562)
解决令牌页面聊天按钮丢失url参数的问题
2023-10-14 12:45:00 +08:00
Mikey
e338de45b6 fix: 404 Component is missing (#592)
* fix: 404 Component is missing

* chore: update 404 page style

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-10-14 12:42:07 +08:00
dependabot[bot]
114587b46f chore(deps): bump golang.org/x/net from 0.10.0 to 0.17.0 (#591)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.10.0 to 0.17.0.
- [Commits](https://github.com/golang/net/compare/v0.10.0...v0.17.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-14 12:34:46 +08:00
JustSong
b4b4acc288 feat: support Tencent's model (close #519) 2023-10-03 14:19:03 +08:00
JustSong
d663de3e3a chore: delete 360's 360GPT_S2_V9.4 2023-10-03 12:52:45 +08:00
JustSong
a85ecace2e chore: delete 360's 360GPT_S2_V9.4 2023-10-03 12:51:53 +08:00
JustSong
fbdea91ea1 docs: update readme 2023-10-02 13:15:35 +08:00
JustSong
8d34b7a77e feat: able to delete all manually disabled channels (close #539) 2023-10-02 13:06:27 +08:00
34 changed files with 680 additions and 146 deletions

View File

@@ -72,6 +72,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+ [x] [360 智脑](https://ai.360.cn)
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
2. 支持配置镜像以及众多第三方代理服务:
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
@@ -94,12 +95,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
15. 支持模型映射,重定向用户的请求模型。
16. 支持失败自动重试。
17. 支持绘图接口。
18. 支持丰富的**自定义**设置,
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
19. 支持通过系统访问令牌访问管理 API。
20. 支持 Cloudflare Turnstile 用户校验。
21. 支持用户管理,支持**多种用户登录注册方式**
20. 支持通过系统访问令牌访问管理 API。
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
@@ -350,6 +352,10 @@ graph LR
13. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -389,6 +395,12 @@ https://openai.justsong.cn
+ 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
6. 报错:`当前分组负载已饱和,请稍后再试`
+ 上游通道 429 了。
7. 升级之后我的数据会丢失吗?
+ 如果使用 MySQL不会。
+ 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
8. 升级之前数据库需要做变更吗?
+ 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统

View File

@@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var UsingSQLite = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var SQLitePath = "one-api.db"
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
@@ -98,6 +95,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
@@ -156,9 +155,10 @@ const (
)
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusDisabled = 2 // also don't use 0
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
const (
@@ -185,30 +185,32 @@ const (
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
}

6
common/database.go Normal file
View File

@@ -0,0 +1,6 @@
package common
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"

View File

@@ -46,8 +46,10 @@ var ModelRatio = map[string]float64{
"claude-2": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
@@ -59,7 +61,7 @@ var ModelRatio = map[string]float64{
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
func ModelRatio2JSONString() string {

View File

@@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
)
@@ -49,6 +50,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
requestURL += "/v1/chat/completions"
}
// for Cloudflare AI gateway: https://github.com/songquanpeng/one-api/pull/639
requestURL = strings.Replace(requestURL, "/v1/v1", "/v1", 1)
jsonData, err := json.Marshal(request)
if err != nil {
@@ -141,7 +144,7 @@ func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)

View File

@@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) {
return
}
func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": rows,
})
return
}
func UpdateChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)

View File

@@ -274,7 +274,7 @@ func init() {
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anturopic",
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
@@ -283,7 +283,7 @@ func init() {
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anturopic",
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2",
Parent: nil,
@@ -306,6 +306,15 @@ func init() {
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
@@ -324,6 +333,15 @@ func init() {
Root: "PaLM-2",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_turbo",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
@@ -424,12 +442,12 @@ func init() {
Parent: nil,
},
{
Id: "360GPT_S2_V9.4",
Id: "hunyuan",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
OwnedBy: "tencent",
Permission: permission,
Root: "360GPT_S2_V9.4",
Root: "hunyuan",
Parent: nil,
},
}

View File

@@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) {
if option.Value == "true" && common.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client ID 以及 GitHub Client Secret",
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret",
})
return
}

View File

@@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

View File

@@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)

287
controller/relay-tencent.go Normal file
View File

@@ -0,0 +1,287 @@
package controller
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"sort"
"strconv"
"strings"
)
// https://cloud.tencent.com/document/product/1729/97732
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []TencentMessage `json:"messages"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage Usage `json:"usage,omitempty"` // token 数量
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.Content,
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
Content: message.Content,
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &stopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var TencentResponse TencentChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}

View File

@@ -6,13 +6,15 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"math"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
@@ -24,13 +26,21 @@ const (
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
httpClient = &http.Client{}
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
@@ -109,13 +119,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
@@ -148,6 +160,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
@@ -179,6 +193,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
@@ -285,6 +301,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
@@ -332,11 +365,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
@@ -377,9 +415,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
@@ -584,6 +620,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}

View File

@@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
return fullRequestURL
}

View File

@@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
for !stop {
select {
case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
@@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
}
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl

View File

@@ -23,7 +23,7 @@ services:
depends_on:
- redis
healthcheck:
test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3

10
go.mod
View File

@@ -15,8 +15,9 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
golang.org/x/crypto v0.9.0
golang.org/x/crypto v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
)
@@ -52,10 +53,9 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.5.2 // indirect
)

17
go.sum
View File

@@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

View File

@@ -25,12 +25,12 @@ func Distribute() func(c *gin.Context) {
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {

View File

@@ -15,10 +15,17 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error

View File

@@ -21,14 +21,18 @@ var (
)
func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil {
return nil, err
}

View File

@@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err
}
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Create(&channels).Error
@@ -176,3 +169,13 @@ func updateChannelUsedQuota(id int, quota int) {
common.SysError("failed to update channel used quota: " + err.Error())
}
}
func DeleteChannelByStatus(status int64) (int64, error) {
result := DB.Where("status = ?", status).Delete(&Channel{})
return result.RowsAffected, result.Error
}
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage

View File

@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
}
redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}

View File

@@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err
}
@@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) {
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
}
updateUserUsedQuotaAndRequestCount(id, quota, 1)
@@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
}
}
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
common.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysError("failed to update user request count: " + err.Error())
}
}
func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username

View File

@@ -6,13 +6,13 @@ import (
"time"
)
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
const (
BatchUpdateTypeUserQuota = iota
BatchUpdateTypeTokenQuota
BatchUpdateTypeUsedQuotaAndRequestCount
BatchUpdateTypeUsedQuota
BatchUpdateTypeChannelUsedQuota
BatchUpdateTypeRequestCount
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
)
var batchUpdateStores []map[int]int
@@ -51,7 +51,7 @@ func batchUpdate() {
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store {
switch i {
case BatchUpdateTypeUserQuota:
@@ -64,8 +64,10 @@ func batchUpdate() {
if err != nil {
common.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuotaAndRequestCount:
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}

View File

@@ -74,6 +74,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel)
}
tokenRoute := apiRouter.Group("/token")

View File

@@ -283,7 +283,9 @@ function App() {
</Suspense>
}
/>
<Route path='*' element={NotFound} />
<Route path='*' element={
<NotFound />
} />
</Routes>
);
}

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
@@ -55,6 +55,7 @@ const ChannelsTable = () => {
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [updatingBalance, setUpdatingBalance] = useState(false);
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
const loadChannels = async (startIdx) => {
const res = await API.get(`/api/channel/?p=${startIdx}`);
@@ -96,7 +97,7 @@ const ChannelsTable = () => {
});
}, []);
const manageChannel = async (id, action, idx, priority) => {
const manageChannel = async (id, action, idx, value) => {
let data = { id };
let res;
switch (action) {
@@ -112,10 +113,20 @@ const ChannelsTable = () => {
res = await API.put('/api/channel/', data);
break;
case 'priority':
if (priority === '') {
if (value === '') {
return;
}
data.priority = parseInt(priority);
data.priority = parseInt(value);
res = await API.put('/api/channel/', data);
break;
case 'weight':
if (value === '') {
return;
}
data.weight = parseInt(value);
if (data.weight < 0) {
data.weight = 0;
}
res = await API.put('/api/channel/', data);
break;
}
@@ -142,9 +153,23 @@ const ChannelsTable = () => {
return <Label basic color='green'>已启用</Label>;
case 2:
return (
<Label basic color='red'>
已禁用
</Label>
<Popup
trigger={<Label basic color='red'>
已禁用
</Label>}
content='本渠道被手动禁用'
basic
/>
);
case 3:
return (
<Popup
trigger={<Label basic color='yellow'>
已禁用
</Label>}
content='本渠道被程序自动禁用'
basic
/>
);
default:
return (
@@ -202,7 +227,6 @@ const ChannelsTable = () => {
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
}
};
@@ -216,6 +240,17 @@ const ChannelsTable = () => {
}
};
const deleteAllDisabledChannels = async () => {
const res = await API.delete(`/api/channel/disabled`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`已删除所有禁用渠道,共计 ${data}`);
await refresh();
} else {
showError(message);
}
};
const updateChannelBalance = async (id, name, idx) => {
const res = await API.get(`/api/channel/update_balance/${id}/`);
const { success, message, balance } = res.data;
@@ -282,7 +317,19 @@ const ChannelsTable = () => {
onChange={handleKeywordChange}
/>
</Form>
{
showPrompt && (
<Message onDismiss={() => {
setShowPrompt(false);
setPromptShown("channel-test");
}}>
当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
模型进行非流式请求实现的因此测试报错并不一定代表通道不可用该功能后续会修复
另外OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为 0对于支持的渠道类型请点击余额进行刷新
</Message>
)
}
<Table basic compact size='small'>
<Table.Header>
<Table.Row>
@@ -343,10 +390,10 @@ const ChannelsTable = () => {
余额
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('priority');
}}
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('priority');
}}
>
优先级
</Table.HeaderCell>
@@ -390,18 +437,18 @@ const ChannelsTable = () => {
</Table.Cell>
<Table.Cell>
<Popup
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
manageChannel(
channel.id,
'priority',
idx,
event.target.value,
);
}}>
<input style={{maxWidth:'60px'}} />
</Input>}
content='渠道选择优先级,越高越优先'
basic
trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => {
manageChannel(
channel.id,
'priority',
idx,
event.target.value
);
}}>
<input style={{ maxWidth: '60px' }} />
</Input>}
content='渠道选择优先级,越高越优先'
basic
/>
</Table.Cell>
<Table.Cell>
@@ -481,6 +528,20 @@ const ChannelsTable = () => {
</Button>
<Button size='small' onClick={updateAllChannelsBalance}
loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
<Popup
trigger={
<Button size='small' loading={loading}>
删除禁用渠道
</Button>
}
on='click'
flowing
hoverable
>
<Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
确认删除
</Button>
</Popup>
<Pagination
floated='right'
activePage={activePage}

View File

@@ -138,7 +138,7 @@ const TokensTable = () => {
let defaultUrl;
if (chatLink) {
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`;
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} else {
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}

View File

@@ -8,6 +8,7 @@ export const CHANNEL_OPTIONS = [
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

View File

@@ -186,4 +186,14 @@ export const verifyJSON = (str) => {
return false;
}
return true;
};
};
export function shouldShowPrompt(id) {
let prompt = localStorage.getItem(`prompt-${id}`);
return !prompt;
}
export function setPromptShown(id) {
localStorage.setItem(`prompt-${id}`, 'true');
}

View File

@@ -19,6 +19,8 @@ function type2secretPrompt(type) {
return '按照如下格式输入APPID|APISecret|APIKey';
case 22:
return '按照如下格式输入APIKey-AppId例如fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
case 23:
return '按照如下格式输入AppId|SecretId|SecretKey';
default:
return '请输入渠道对应的鉴权密钥';
}
@@ -64,19 +66,22 @@ const EditChannel = () => {
localModels = ['PaLM-2'];
break;
case 15:
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = ['SparkDesk'];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
break;
case 23:
localModels = ['hunyuan'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));

View File

@@ -1,19 +1,12 @@
import React from 'react';
import { Segment, Header } from 'semantic-ui-react';
import { Message } from 'semantic-ui-react';
const NotFound = () => (
<>
<Header
block
as="h4"
content="404"
attached="top"
icon="info"
className="small-icon"
/>
<Segment attached="bottom">
未找到所请求的页面
</Segment>
<Message negative>
<Message.Header>页面不存在</Message.Header>
<p>请检查你的浏览器地址是否正确</p>
</Message>
</>
);