Compare commits

...

13 Commits

Author SHA1 Message Date
JustSong
12ef9679a7 fix: fix url not passing when using custom chat_link 2023-09-17 17:19:12 +08:00
JustSong
328aa68255 feat: able to delete logs now (close #486) 2023-09-17 17:09:56 +08:00
JustSong
4335f005a6 feat: create new log file when too many logs recorded 2023-09-17 16:35:30 +08:00
JustSong
fe26a1448d docs: update README 2023-09-17 15:41:01 +08:00
JustSong
42451d9d02 refactor: update logging related logic 2023-09-17 15:39:46 +08:00
JustSong
25c4c111ab fix: only enable cors for relay routers to avoid csrf attack 2023-09-17 11:44:38 +08:00
JustSong
0d50ad4b2b chore: update channel test prompt 2023-09-17 11:34:06 +08:00
JustSong
959bcdef88 chore: update error code 2023-09-17 11:30:20 +08:00
JustSong
39ae8075e4 fix: fix oauth2 state not checking 2023-09-15 00:24:20 +08:00
JustSong
b57a0eca16 docs: update readme 2023-09-13 23:22:53 +08:00
Junyan Qin
1b4cc78890 docs: add QChatGPT (#522)
* doc(README.md): 添加支持One API的项目QChatGPT

* Update README.md

---------

Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2023-09-13 23:18:53 +08:00
JustSong
420c375140 perf: only return quota when it's not zero 2023-09-13 22:05:10 +08:00
JustSong
01863d3e44 fix: fix quota not return when error occurred (close #518) 2023-09-13 21:50:45 +08:00
30 changed files with 338 additions and 136 deletions

3
.gitignore vendored
View File

@@ -4,4 +4,5 @@ upload
*.exe
*.db
build
*.db-journal
*.db-journal
logs

View File

@@ -71,10 +71,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [360 智脑](https://ai.360.cn)
2. 支持配置镜像以及众多第三方代理服务:
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -211,6 +211,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
#### QChatGPT - QQ机器人
项目主页https://github.com/RockChinQ/QChatGPT
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
### 部署到第三方平台
<details>
<summary><strong>部署到 Sealos </strong></summary>
@@ -318,7 +325,7 @@ graph LR
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
+ 例子:`--port 3000`
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下
+ 例子:`--log-dir ./logs`
3. `--version`: 打印系统版本号并退出。
4. `--help`: 查看命令的使用帮助和参数说明。
@@ -364,4 +371,4 @@ https://openai.justsong.cn
同样适用于基于本项目的二开项目。
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。

View File

@@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const (
RoleGuestUser = 0
RoleCommonUser = 1

View File

@@ -12,7 +12,7 @@ var (
Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "", "specify the log directory")
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
)
func printHelp() {

View File

@@ -1,29 +1,47 @@
package common
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"os"
"path/filepath"
"sync"
"time"
)
func SetupGinLog() {
const (
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
if *LogDir != "" {
commonLogPath := filepath.Join(*LogDir, "common.log")
errorLogPath := filepath.Join(*LogDir, "error.log")
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
}
@@ -37,6 +55,36 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func LogInfo(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
SetupLogger()
}()
}
}
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

View File

@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
}
return num
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}

View File

@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
if err != nil {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
Type: "upstream_error",
}
c.JSON(200, gin.H{
"error": openAIError,

View File

@@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
func GitHubOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
GitHubBind(c)
@@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) {
})
return
}
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}

View File

@@ -2,6 +2,7 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
@@ -20,17 +21,18 @@ func GetAllLogs(c *gin.Context) {
modelName := c.Query("model_name")
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func GetUserLogs(c *gin.Context) {
@@ -46,34 +48,36 @@ func GetUserLogs(c *gin.Context) {
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func SearchUserLogs(c *gin.Context) {
@@ -81,17 +85,18 @@ func SearchUserLogs(c *gin.Context) {
userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func GetLogsStat(c *gin.Context) {
@@ -103,7 +108,7 @@ func GetLogsStat(c *gin.Context) {
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
@@ -111,6 +116,7 @@ func GetLogsStat(c *gin.Context) {
//"token": tokenNum,
},
})
return
}
func GetLogsSelfStat(c *gin.Context) {
@@ -122,7 +128,7 @@ func GetLogsSelfStat(c *gin.Context) {
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(200, gin.H{
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
@@ -130,4 +136,30 @@ func GetLogsSelfStat(c *gin.Context) {
//"token": tokenNum,
},
})
return
}
func DeleteHistoryLogs(c *gin.Context) {
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
if targetTimestamp == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "target timestamp is required",
})
return
}
count, err := model.DeleteOldLog(targetTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
return
}

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
var audioResponse AudioResponse
defer func() {
defer func(ctx context.Context) {
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
@@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
var textResponse ImageResponse
defer func() {
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
@@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
}(c.Request.Context())
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -210,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if consumeQuota && preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
@@ -347,6 +349,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
@@ -355,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenName := c.GetString("token_name")
channelId := c.GetInt("channel_id")
defer func() {
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
if consumeQuota {
@@ -378,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {

View File

@@ -146,7 +146,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
StatusCode: resp.StatusCode,
OpenAIError: OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "one_api_error",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},

View File

@@ -196,6 +196,7 @@ func Relay(c *gin.Context) {
err = relayTextHelper(c, relayMode)
}
if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
@@ -207,12 +208,13 @@ func Relay(c *gin.Context) {
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id")

View File

@@ -21,7 +21,7 @@ var buildFS embed.FS
var indexPage []byte
func main() {
common.SetupGinLog()
common.SetupLogger()
common.SysLog("One API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
@@ -85,11 +85,12 @@ func main() {
controller.InitTokenEncoders()
// Initialize HTTP server
server := gin.Default()
server := gin.New()
server.Use(gin.Recovery())
// This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.CORS())
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
server.Use(sessions.Sessions("session", store))

View File

@@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0]
token, err := model.ValidateUserToken(key)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.IsUserEnabled(token.UserId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "用户已被封禁",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
@@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
} else {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "普通用户不支持指定渠道",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
}

View File

@@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
return
}
if channel.Status != common.ChannelStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "该渠道已被禁用",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
@@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的请求",
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"message": message,
"type": "one_api_error",
},
})
c.Abort()
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
}

25
middleware/logger.go Normal file
View File

@@ -0,0 +1,25 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
requestID,
param.StatusCode,
param.Latency,
param.ClientIP,
param.Method,
param.Path,
)
}))
}

18
middleware/request-id.go Normal file
View File

@@ -0,0 +1,18 @@
package middleware
import (
"context"
"github.com/gin-gonic/gin"
"one-api/common"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8)
c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id)
c.Next()
}
}

17
middleware/utils.go Normal file
View File

@@ -0,0 +1,17 @@
package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
}

View File

@@ -1,6 +1,8 @@
package model
import (
"context"
"fmt"
"gorm.io/gorm"
"one-api/common"
)
@@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
@@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
}
err := DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
common.LogError(ctx, "failed to record log: "+err.Error())
}
}
@@ -166,3 +169,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx.Where("type = ?", LogTypeConsume).Scan(&token)
return token
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}

View File

@@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
@@ -97,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
}
logRoute := apiRouter.Group("/log")
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)

View File

@@ -8,6 +8,7 @@ import (
)
func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
@@ -195,6 +195,7 @@ const ChannelsTable = () => {
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
}
};

View File

@@ -13,8 +13,8 @@ const GitHubOAuth = () => {
let navigate = useNavigate();
const sendCode = async (code, count) => {
const res = await API.get(`/api/oauth/github?code=${code}`);
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
@@ -36,13 +36,14 @@ const GitHubOAuth = () => {
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, count);
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
sendCode(code, 0).then();
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (

View File

@@ -3,6 +3,7 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showSuccess } from '../helpers';
import { getOAuthState, onGitHubOAuthClicked } from './utils';
const LoginForm = () => {
const [inputs, setInputs] = useState({
@@ -31,12 +32,6 @@ const LoginForm = () => {
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const onGitHubOAuthClicked = () => {
window.open(
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
);
};
const onWeChatLoginClicked = () => {
setShowWeChatLoginModal(true);
};
@@ -131,7 +126,7 @@ const LoginForm = () => {
circular
color='black'
icon='github'
onClick={onGitHubOAuthClicked}
onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
/>
) : (
<></>

View File

@@ -1,8 +1,9 @@
import React, { useEffect, useState } from 'react';
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
import { API, showError, verifyJSON } from '../helpers';
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
const OperationSetting = () => {
let now = new Date();
let [inputs, setInputs] = useState({
QuotaForNewUser: 0,
QuotaForInviter: 0,
@@ -20,10 +21,11 @@ const OperationSetting = () => {
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
RetryTimes: 0,
RetryTimes: 0
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
const getOptions = async () => {
const res = await API.get('/api/option/');
@@ -130,6 +132,17 @@ const OperationSetting = () => {
}
};
const deleteHistoryLogs = async () => {
console.log(inputs);
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`${data} 条日志已清理!`);
return;
}
showError('日志清理失败:' + message);
};
return (
<Grid columns={1}>
<Grid.Column>
@@ -179,12 +192,6 @@ const OperationSetting = () => {
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DisplayInCurrencyEnabled === 'true'}
label='以货币形式显示额度'
@@ -208,6 +215,28 @@ const OperationSetting = () => {
submitConfig('general').then();
}}>保存通用设置</Form.Button>
<Divider />
<Header as='h3'>
日志设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, { name, value }) => {
setHistoryTimestamp(value);
}} />
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider />
<Header as='h3'>
监控设置
</Header>

View File

@@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User';
import { onGitHubOAuthClicked } from './utils';
const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
@@ -130,12 +131,6 @@ const PersonalSetting = () => {
}
};
const openGitHubOAuth = () => {
window.open(
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
);
};
const sendVerificationCode = async () => {
setDisableButton(true);
if (inputs.email === '') return;
@@ -249,7 +244,7 @@ const PersonalSetting = () => {
</Modal>
{
status.github_oauth && (
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
<Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
)
}
<Button

View File

@@ -96,7 +96,7 @@ const TokensTable = () => {
let nextUrl;
if (nextLink) {
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} else {
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}

View File

@@ -0,0 +1,20 @@
import { API, showError } from '../helpers';
export async function getOAuthState() {
const res = await API.get('/api/oauth/state');
const { success, message, data } = res.data;
if (success) {
return data;
} else {
showError(message);
return '';
}
}
export async function onGitHubOAuthClicked(github_client_id) {
const state = await getOAuthState();
if (!state) return;
window.open(
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
);
}