mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-24 02:13:42 +08:00
Compare commits
22 Commits
v0.5.5-alp
...
v0.5.5
Author | SHA1 | Date | |
---|---|---|---|
|
12ef9679a7 | ||
|
328aa68255 | ||
|
4335f005a6 | ||
|
fe26a1448d | ||
|
42451d9d02 | ||
|
25c4c111ab | ||
|
0d50ad4b2b | ||
|
959bcdef88 | ||
|
39ae8075e4 | ||
|
b57a0eca16 | ||
|
1b4cc78890 | ||
|
420c375140 | ||
|
01863d3e44 | ||
|
d0a0e871e1 | ||
|
bd6fe1e93c | ||
|
c55bb67818 | ||
|
0f949c3782 | ||
|
a721a5b6f9 | ||
|
276163affd | ||
|
621eb91b46 | ||
|
7e575abb95 | ||
|
9db93316c4 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,3 +5,4 @@ upload
|
||||
*.db
|
||||
build
|
||||
*.db-journal
|
||||
logs
|
17
README.md
17
README.md
@@ -71,10 +71,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
+ [x] [360 智脑](https://ai.360.cn)
|
||||
2. 支持配置镜像以及众多第三方代理服务:
|
||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
||||
+ [x] [API2D](https://api2d.com/r/197971)
|
||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
||||
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
||||
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
|
||||
3. 支持通过**负载均衡**的方式访问多个渠道。
|
||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||
@@ -109,7 +109,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
|
||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
||||
|
||||
如果启动失败,请添加 `--privileged=true`,具体参考 #482。
|
||||
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
|
||||
|
||||
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
|
||||
|
||||
@@ -211,6 +211,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
||||
|
||||
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
|
||||
|
||||
#### QChatGPT - QQ机器人
|
||||
项目主页:https://github.com/RockChinQ/QChatGPT
|
||||
|
||||
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
|
||||
|
||||
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
|
||||
|
||||
### 部署到第三方平台
|
||||
<details>
|
||||
<summary><strong>部署到 Sealos </strong></summary>
|
||||
@@ -308,13 +315,17 @@ graph LR
|
||||
+ 例子:`POLLING_INTERVAL=5`
|
||||
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||
+ 例子:`BATCH_UPDATE_ENABLED=true`
|
||||
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
|
||||
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
|
||||
+ 例子:`BATCH_UPDATE_INTERVAL=5`
|
||||
12. 请求频率限制:
|
||||
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
||||
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
||||
|
||||
### 命令行参数
|
||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||
+ 例子:`--port 3000`
|
||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。
|
||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。
|
||||
+ 例子:`--log-dir ./logs`
|
||||
3. `--version`: 打印系统版本号并退出。
|
||||
4. `--help`: 查看命令的使用帮助和参数说明。
|
||||
|
@@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
)
|
||||
|
||||
const (
|
||||
RoleGuestUser = 0
|
||||
RoleCommonUser = 1
|
||||
@@ -114,10 +118,10 @@ var (
|
||||
// All duration's unit is seconds
|
||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||
var (
|
||||
GlobalApiRateLimitNum = 180
|
||||
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||
|
||||
GlobalWebRateLimitNum = 60
|
||||
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||
|
||||
UploadRateLimitNum = 10
|
||||
@@ -179,29 +183,31 @@ const (
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ var (
|
||||
Port = flag.Int("port", 3000, "the listening port")
|
||||
PrintVersion = flag.Bool("version", false, "print version and exit")
|
||||
PrintHelp = flag.Bool("help", false, "print help and exit")
|
||||
LogDir = flag.String("log-dir", "", "specify the log directory")
|
||||
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
|
||||
)
|
||||
|
||||
func printHelp() {
|
||||
|
@@ -1,29 +1,47 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SetupGinLog() {
|
||||
const (
|
||||
loggerINFO = "INFO"
|
||||
loggerWarn = "WARN"
|
||||
loggerError = "ERR"
|
||||
)
|
||||
|
||||
const maxLogCount = 1000000
|
||||
|
||||
var logCount int
|
||||
var setupLogLock sync.Mutex
|
||||
var setupLogWorking bool
|
||||
|
||||
func SetupLogger() {
|
||||
if *LogDir != "" {
|
||||
commonLogPath := filepath.Join(*LogDir, "common.log")
|
||||
errorLogPath := filepath.Join(*LogDir, "error.log")
|
||||
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
ok := setupLogLock.TryLock()
|
||||
if !ok {
|
||||
log.Println("setup log is already working")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
setupLogLock.Unlock()
|
||||
setupLogWorking = false
|
||||
}()
|
||||
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
|
||||
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatal("failed to open log file")
|
||||
}
|
||||
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatal("failed to open log file")
|
||||
}
|
||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
|
||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
|
||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +55,36 @@ func SysError(s string) {
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||
}
|
||||
|
||||
func LogInfo(ctx context.Context, msg string) {
|
||||
logHelper(ctx, loggerINFO, msg)
|
||||
}
|
||||
|
||||
func LogWarn(ctx context.Context, msg string) {
|
||||
logHelper(ctx, loggerWarn, msg)
|
||||
}
|
||||
|
||||
func LogError(ctx context.Context, msg string) {
|
||||
logHelper(ctx, loggerError, msg)
|
||||
}
|
||||
|
||||
func logHelper(ctx context.Context, level string, msg string) {
|
||||
writer := gin.DefaultErrorWriter
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(RequestIdKey)
|
||||
now := time.Now()
|
||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||
logCount++ // we don't need accurate count, so no lock here
|
||||
if logCount > maxLogCount && !setupLogWorking {
|
||||
logCount = 0
|
||||
setupLogWorking = true
|
||||
go func() {
|
||||
SetupLogger()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func FatalLog(v ...any) {
|
||||
t := time.Now()
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||
|
@@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
|
||||
"qwen-plus-v1": 0.5715, // Same as above
|
||||
"SparkDesk": 0.8572, // TBD
|
||||
"qwen-v1": 0.8572, // ¥0.012 / 1k tokens
|
||||
"qwen-plus-v1": 1, // ¥0.014 / 1k tokens
|
||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
|
@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
func GetTimeString() string {
|
||||
now := time.Now()
|
||||
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
||||
}
|
||||
|
||||
func Max(a int, b int) int {
|
||||
if a >= b {
|
||||
return a
|
||||
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func MessageWithRequestId(message string, id string) string {
|
||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||
}
|
||||
|
@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
|
||||
if err != nil {
|
||||
openAIError := OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
Type: "upstream_error",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": openAIError,
|
||||
|
@@ -14,7 +14,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
fallthrough
|
||||
@@ -32,6 +32,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
|
||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||
case common.ChannelTypeAzure:
|
||||
request.Model = "gpt-35-turbo"
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
||||
}
|
||||
}()
|
||||
default:
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
}
|
||||
|
@@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
|
||||
func GitHubOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
GitHubBind(c)
|
||||
@@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := common.GetRandomString(12)
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": state,
|
||||
})
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
@@ -20,17 +21,18 @@ func GetAllLogs(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
@@ -46,34 +48,36 @@ func GetUserLogs(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUserLogs(c *gin.Context) {
|
||||
@@ -81,17 +85,18 @@ func SearchUserLogs(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogsStat(c *gin.Context) {
|
||||
@@ -103,7 +108,7 @@ func GetLogsStat(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
@@ -111,6 +116,7 @@ func GetLogsStat(c *gin.Context) {
|
||||
//"token": tokenNum,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogsSelfStat(c *gin.Context) {
|
||||
@@ -122,7 +128,7 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
@@ -130,4 +136,30 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
//"token": tokenNum,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteHistoryLogs(c *gin.Context) {
|
||||
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
|
||||
if targetTimestamp == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "target timestamp is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
count, err := model.DeleteOldLog(targetTimestamp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@@ -360,6 +360,15 @@ func init() {
|
||||
Root: "qwen-plus-v1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-v1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "text-embedding-v1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "SparkDesk",
|
||||
Object: "model",
|
||||
|
@@ -35,6 +35,29 @@ type AliChatRequest struct {
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
@@ -44,6 +67,7 @@ type AliError struct {
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
Texts: request.ParseInput(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var aliResponse AliEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
|
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
}
|
||||
var audioResponse AudioResponse
|
||||
|
||||
defer func() {
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
quota := countTokenText(audioResponse.Text, audioModel)
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
@@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||
model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}()
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
|
@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||
baiduEmbeddingRequest := BaiduEmbeddingRequest{
|
||||
Input: nil,
|
||||
return &BaiduEmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
switch request.Input.(type) {
|
||||
case string:
|
||||
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
|
||||
case []any:
|
||||
for _, item := range request.Input.([]any) {
|
||||
if str, ok := item.(string); ok {
|
||||
baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &baiduEmbeddingRequest
|
||||
}
|
||||
|
||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
|
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
}
|
||||
var textResponse ImageResponse
|
||||
|
||||
defer func() {
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota {
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
@@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -174,6 +175,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||
case APITypeAli:
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
if relayMode == RelayModeEmbeddings {
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
}
|
||||
case APITypeAIProxyLibrary:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||
}
|
||||
@@ -207,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
if consumeQuota && preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
@@ -262,8 +267,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeAli:
|
||||
aliRequest := requestOpenAI2Ali(textRequest)
|
||||
jsonStr, err := json.Marshal(aliRequest)
|
||||
var jsonStr []byte
|
||||
var err error
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
|
||||
jsonStr, err = json.Marshal(aliEmbeddingRequest)
|
||||
default:
|
||||
aliRequest := requestOpenAI2Ali(textRequest)
|
||||
jsonStr, err = json.Marshal(aliRequest)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -336,6 +349,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return relayErrorHandler(resp)
|
||||
}
|
||||
}
|
||||
@@ -344,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
tokenName := c.GetString("token_name")
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
||||
defer func() {
|
||||
defer func(ctx context.Context) {
|
||||
// c.Writer.Flush()
|
||||
go func() {
|
||||
if consumeQuota {
|
||||
@@ -367,22 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if isStream {
|
||||
@@ -503,7 +524,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
err, usage := aliHandler(c, resp)
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = aliHandler(c, resp)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -146,7 +146,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: OpenAIError{
|
||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||
Type: "one_api_error",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
|
@@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct {
|
||||
Functions any `json:"functions,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
var input []string
|
||||
switch r.Input.(type) {
|
||||
case string:
|
||||
input = []string{r.Input.(string)}
|
||||
case []any:
|
||||
input = make([]string, 0, len(r.Input.([]any)))
|
||||
for _, item := range r.Input.([]any) {
|
||||
if str, ok := item.(string); ok {
|
||||
input = append(input, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
@@ -177,6 +196,7 @@ func Relay(c *gin.Context) {
|
||||
err = relayTextHelper(c, relayMode)
|
||||
}
|
||||
if err != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
retryTimesStr := c.Query("retry")
|
||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||
if retryTimesStr == "" {
|
||||
@@ -188,12 +208,13 @@ func Relay(c *gin.Context) {
|
||||
if err.StatusCode == http.StatusTooManyRequests {
|
||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.OpenAIError,
|
||||
})
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
@@ -523,5 +523,6 @@
|
||||
"按照如下格式输入:": "Enter in the following format:",
|
||||
"模型版本": "Model version",
|
||||
"请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
|
||||
"点击查看": "click to view"
|
||||
"点击查看": "click to view",
|
||||
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
|
||||
}
|
||||
|
9
main.go
9
main.go
@@ -21,7 +21,7 @@ var buildFS embed.FS
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
common.SetupGinLog()
|
||||
common.SetupLogger()
|
||||
common.SysLog("One API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -85,11 +85,12 @@ func main() {
|
||||
controller.InitTokenEncoders()
|
||||
|
||||
// Initialize HTTP server
|
||||
server := gin.Default()
|
||||
server := gin.New()
|
||||
server.Use(gin.Recovery())
|
||||
// This will cause SSE not to work!!!
|
||||
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
server.Use(middleware.CORS())
|
||||
|
||||
server.Use(middleware.RequestId())
|
||||
middleware.SetUpLogger(server)
|
||||
// Initialize session store
|
||||
store := cookie.NewStore([]byte(common.SessionSecret))
|
||||
server.Use(sessions.Sessions("session", store))
|
||||
|
@@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) {
|
||||
key = parts[0]
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
if !model.CacheIsUserEnabled(token.UserId) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
userEnabled, err := model.IsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !userEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
@@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("channelId", parts[1])
|
||||
} else {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "普通用户不支持指定渠道",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "无效的渠道 ID",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||
return
|
||||
}
|
||||
channel, err = model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "无效的渠道 ID",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "该渠道已被禁用",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
@@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "无效的请求",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
@@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"message": message,
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
25
middleware/logger.go
Normal file
25
middleware/logger.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func SetUpLogger(server *gin.Engine) {
|
||||
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
var requestID string
|
||||
if param.Keys != nil {
|
||||
requestID = param.Keys[common.RequestIdKey].(string)
|
||||
}
|
||||
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
requestID,
|
||||
param.StatusCode,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
param.Method,
|
||||
param.Path,
|
||||
)
|
||||
}))
|
||||
}
|
18
middleware/request-id.go
Normal file
18
middleware/request-id.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func RequestId() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
id := common.GetTimeString() + common.GetRandomString(8)
|
||||
c.Set(common.RequestIdKey, id)
|
||||
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Header(common.RequestIdKey, id)
|
||||
c.Next()
|
||||
}
|
||||
}
|
17
middleware/utils.go
Normal file
17
middleware/utils.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
common.LogError(c.Request.Context(), message)
|
||||
}
|
@@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func CacheIsUserEnabled(userId int) bool {
|
||||
func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsUserEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||
if err != nil {
|
||||
status := common.UserStatusDisabled
|
||||
if IsUserEnabled(userId) {
|
||||
status = common.UserStatusEnabled
|
||||
}
|
||||
enabled = fmt.Sprintf("%d", status)
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
return enabled == "1"
|
||||
|
||||
userEnabled, err := IsUserEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if userEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
|
12
model/log.go
12
model/log.go
@@ -1,6 +1,8 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
)
|
||||
@@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||
func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
@@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
|
||||
}
|
||||
err := DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to record log: " + err.Error())
|
||||
common.LogError(ctx, "failed to record log: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,3 +169,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
tx.Where("type = ?", LogTypeConsume).Scan(&token)
|
||||
return token
|
||||
}
|
||||
|
||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
}
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
if err == nil {
|
||||
if token.Status == common.TokenStatusExhausted {
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
} else if token.Status == common.TokenStatusExpired {
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
if token.Status != common.TokenStatusEnabled {
|
||||
return nil, errors.New("该令牌状态不可用")
|
||||
}
|
||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
||||
token.Status = common.TokenStatusExpired
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
if !common.RedisEnabled {
|
||||
token.Status = common.TokenStatusExpired
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||
token.Status = common.TokenStatusExhausted
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
if !common.RedisEnabled {
|
||||
// in this case, we can make sure the token is exhausted
|
||||
token.Status = common.TokenStatusExhausted
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
}
|
||||
go func() {
|
||||
token.AccessedTime = common.GetTimestamp()
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token" + err.Error())
|
||||
}
|
||||
}()
|
||||
return token, nil
|
||||
}
|
||||
return nil, errors.New("无效的令牌")
|
||||
@@ -141,8 +144,9 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
|
||||
func increaseTokenQuota(id int, quota int) (err error) {
|
||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||
"accessed_time": common.GetTimestamp(),
|
||||
},
|
||||
).Error
|
||||
return err
|
||||
@@ -162,8 +166,9 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
|
||||
func decreaseTokenQuota(id int, quota int) (err error) {
|
||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||
"accessed_time": common.GetTimestamp(),
|
||||
},
|
||||
).Error
|
||||
return err
|
||||
|
@@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
|
||||
return user.Role >= common.RoleAdminUser
|
||||
}
|
||||
|
||||
func IsUserEnabled(userId int) bool {
|
||||
func IsUserEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false
|
||||
return false, errors.New("user id is empty")
|
||||
}
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
||||
if err != nil {
|
||||
common.SysError("no such user " + err.Error())
|
||||
return false
|
||||
return false, err
|
||||
}
|
||||
return user.Status == common.UserStatusEnabled
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
func ValidateAccessToken(token string) (user *User) {
|
||||
|
@@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
|
||||
@@ -97,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
logRoute := apiRouter.Group("/log")
|
||||
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
|
||||
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
|
||||
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
|
||||
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
|
||||
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func SetRelayRouter(router *gin.Engine) {
|
||||
router.Use(middleware.CORS())
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.TokenAuth())
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
|
||||
import { Link } from 'react-router-dom';
|
||||
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
|
||||
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
|
||||
|
||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
|
||||
import { renderGroup, renderNumber } from '../helpers/render';
|
||||
@@ -195,6 +195,7 @@ const ChannelsTable = () => {
|
||||
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
|
||||
} else {
|
||||
showError(message);
|
||||
showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -13,8 +13,8 @@ const GitHubOAuth = () => {
|
||||
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, count) => {
|
||||
const res = await API.get(`/api/oauth/github?code=${code}`);
|
||||
const sendCode = async (code, state, count) => {
|
||||
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (message === 'bind') {
|
||||
@@ -36,13 +36,14 @@ const GitHubOAuth = () => {
|
||||
count++;
|
||||
setPrompt(`出现错误,第 ${count} 次重试中...`);
|
||||
await new Promise((resolve) => setTimeout(resolve, count * 2000));
|
||||
await sendCode(code, count);
|
||||
await sendCode(code, state, count);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let code = searchParams.get('code');
|
||||
sendCode(code, 0).then();
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
|
@@ -3,6 +3,7 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
|
||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
import { API, getLogo, showError, showSuccess } from '../helpers';
|
||||
import { getOAuthState, onGitHubOAuthClicked } from './utils';
|
||||
|
||||
const LoginForm = () => {
|
||||
const [inputs, setInputs] = useState({
|
||||
@@ -31,12 +32,6 @@ const LoginForm = () => {
|
||||
|
||||
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
|
||||
|
||||
const onGitHubOAuthClicked = () => {
|
||||
window.open(
|
||||
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
|
||||
);
|
||||
};
|
||||
|
||||
const onWeChatLoginClicked = () => {
|
||||
setShowWeChatLoginModal(true);
|
||||
};
|
||||
@@ -131,7 +126,7 @@ const LoginForm = () => {
|
||||
circular
|
||||
color='black'
|
||||
icon='github'
|
||||
onClick={onGitHubOAuthClicked}
|
||||
onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
|
||||
import { API, showError, verifyJSON } from '../helpers';
|
||||
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
|
||||
|
||||
const OperationSetting = () => {
|
||||
let now = new Date();
|
||||
let [inputs, setInputs] = useState({
|
||||
QuotaForNewUser: 0,
|
||||
QuotaForInviter: 0,
|
||||
@@ -20,10 +21,11 @@ const OperationSetting = () => {
|
||||
DisplayInCurrencyEnabled: '',
|
||||
DisplayTokenStatEnabled: '',
|
||||
ApproximateTokenEnabled: '',
|
||||
RetryTimes: 0,
|
||||
RetryTimes: 0
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
@@ -130,6 +132,17 @@ const OperationSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const deleteHistoryLogs = async () => {
|
||||
console.log(inputs);
|
||||
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(`${data} 条日志已清理!`);
|
||||
return;
|
||||
}
|
||||
showError('日志清理失败:' + message);
|
||||
};
|
||||
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
@@ -179,12 +192,6 @@ const OperationSetting = () => {
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label='启用额度消费日志记录'
|
||||
name='LogConsumeEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
||||
label='以货币形式显示额度'
|
||||
@@ -208,6 +215,28 @@ const OperationSetting = () => {
|
||||
submitConfig('general').then();
|
||||
}}>保存通用设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
日志设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label='启用额度消费日志记录'
|
||||
name='LogConsumeEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
|
||||
name='history_timestamp'
|
||||
onChange={(e, { name, value }) => {
|
||||
setHistoryTimestamp(value);
|
||||
}} />
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
deleteHistoryLogs().then();
|
||||
}}>清理历史日志</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
监控设置
|
||||
</Header>
|
||||
|
@@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
|
||||
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { UserContext } from '../context/User';
|
||||
import { onGitHubOAuthClicked } from './utils';
|
||||
|
||||
const PersonalSetting = () => {
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
@@ -130,12 +131,6 @@ const PersonalSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const openGitHubOAuth = () => {
|
||||
window.open(
|
||||
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
|
||||
);
|
||||
};
|
||||
|
||||
const sendVerificationCode = async () => {
|
||||
setDisableButton(true);
|
||||
if (inputs.email === '') return;
|
||||
@@ -249,7 +244,7 @@ const PersonalSetting = () => {
|
||||
</Modal>
|
||||
{
|
||||
status.github_oauth && (
|
||||
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
|
||||
<Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
|
||||
)
|
||||
}
|
||||
<Button
|
||||
|
@@ -96,7 +96,7 @@ const TokensTable = () => {
|
||||
let nextUrl;
|
||||
|
||||
if (nextLink) {
|
||||
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
|
||||
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
} else {
|
||||
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
}
|
||||
|
20
web/src/components/utils.js
Normal file
20
web/src/components/utils.js
Normal file
@@ -0,0 +1,20 @@
|
||||
import { API, showError } from '../helpers';
|
||||
|
||||
export async function getOAuthState() {
|
||||
const res = await API.get('/api/oauth/state');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
return data;
|
||||
} else {
|
||||
showError(message);
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
export async function onGitHubOAuthClicked(github_client_id) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
window.open(
|
||||
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
|
||||
);
|
||||
}
|
@@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||
{ key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
|
||||
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
||||
|
@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
|
||||
'gpt-4-32k-0314': 'gpt-4-32k'
|
||||
};
|
||||
|
||||
function type2secretPrompt(type) {
|
||||
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
|
||||
switch (type) {
|
||||
case 15:
|
||||
return '按照如下格式输入:APIKey|SecretKey';
|
||||
case 18:
|
||||
return '按照如下格式输入:APPID|APISecret|APIKey';
|
||||
case 22:
|
||||
return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
|
||||
default:
|
||||
return '请输入渠道对应的鉴权密钥';
|
||||
}
|
||||
}
|
||||
|
||||
const EditChannel = () => {
|
||||
const params = useParams();
|
||||
const navigate = useNavigate();
|
||||
@@ -53,7 +67,7 @@ const EditChannel = () => {
|
||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
||||
break;
|
||||
case 17:
|
||||
localModels = ['qwen-v1', 'qwen-plus-v1'];
|
||||
localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
|
||||
break;
|
||||
case 16:
|
||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||
@@ -193,6 +207,24 @@ const EditChannel = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const addCustomModel = () => {
|
||||
if (customModel.trim() === '') return;
|
||||
if (inputs.models.includes(customModel)) return;
|
||||
let localModels = [...inputs.models];
|
||||
localModels.push(customModel);
|
||||
let localModelOptions = [];
|
||||
localModelOptions.push({
|
||||
key: customModel,
|
||||
text: customModel,
|
||||
value: customModel
|
||||
});
|
||||
setModelOptions(modelOptions => {
|
||||
return [...modelOptions, ...localModelOptions];
|
||||
});
|
||||
setCustomModel('');
|
||||
handleInputChange(null, { name: 'models', value: localModels });
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Segment loading={loading}>
|
||||
@@ -336,29 +368,19 @@ const EditChannel = () => {
|
||||
}}>清除所有模型</Button>
|
||||
<Input
|
||||
action={
|
||||
<Button type={'button'} onClick={() => {
|
||||
if (customModel.trim() === '') return;
|
||||
if (inputs.models.includes(customModel)) return;
|
||||
let localModels = [...inputs.models];
|
||||
localModels.push(customModel);
|
||||
let localModelOptions = [];
|
||||
localModelOptions.push({
|
||||
key: customModel,
|
||||
text: customModel,
|
||||
value: customModel
|
||||
});
|
||||
setModelOptions(modelOptions => {
|
||||
return [...modelOptions, ...localModelOptions];
|
||||
});
|
||||
setCustomModel('');
|
||||
handleInputChange(null, { name: 'models', value: localModels });
|
||||
}}>填入</Button>
|
||||
<Button type={'button'} onClick={addCustomModel}>填入</Button>
|
||||
}
|
||||
placeholder='输入自定义模型名称'
|
||||
value={customModel}
|
||||
onChange={(e, { value }) => {
|
||||
setCustomModel(value);
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
addCustomModel();
|
||||
e.preventDefault();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<Form.Field>
|
||||
@@ -389,7 +411,7 @@ const EditChannel = () => {
|
||||
label='密钥'
|
||||
name='key'
|
||||
required
|
||||
placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
|
||||
placeholder={type2secretPrompt(inputs.type)}
|
||||
onChange={handleInputChange}
|
||||
value={inputs.key}
|
||||
autoComplete='new-password'
|
||||
@@ -407,7 +429,7 @@ const EditChannel = () => {
|
||||
)
|
||||
}
|
||||
{
|
||||
inputs.type !== 3 && inputs.type !== 8 && (
|
||||
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
|
||||
<Form.Field>
|
||||
<Form.Input
|
||||
label='代理'
|
||||
@@ -420,6 +442,20 @@ const EditChannel = () => {
|
||||
</Form.Field>
|
||||
)
|
||||
}
|
||||
{
|
||||
inputs.type === 22 && (
|
||||
<Form.Field>
|
||||
<Form.Input
|
||||
label='私有部署地址'
|
||||
name='base_url'
|
||||
placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
|
||||
onChange={handleInputChange}
|
||||
value={inputs.base_url}
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
</Form.Field>
|
||||
)
|
||||
}
|
||||
<Button onClick={handleCancel}>取消</Button>
|
||||
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
|
||||
</Form>
|
||||
|
Reference in New Issue
Block a user