From d0402f90863a10036bfc0ec7901295ba14dbf0e5 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 31 Jan 2025 17:54:04 +0800 Subject: [PATCH] feat: record request_id --- common/helper/helper.go | 19 +++++++++++++++++-- common/logger/logger.go | 10 +++++----- controller/auth/github.go | 11 +++++++---- controller/auth/lark.go | 11 +++++++---- controller/auth/oidc.go | 11 +++++++---- controller/auth/wechat.go | 11 +++++++---- controller/user.go | 15 ++++++++++----- middleware/request-id.go | 4 ++-- model/log.go | 35 +++++++++++++++++++---------------- model/redemption.go | 9 ++++++--- model/user.go | 15 +++++++++------ 11 files changed, 96 insertions(+), 55 deletions(-) diff --git a/common/helper/helper.go b/common/helper/helper.go index df7b0a5f..65f4fd29 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -1,9 +1,8 @@ package helper import ( + "context" "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/random" "html/template" "log" "net" @@ -11,6 +10,10 @@ import ( "runtime" "strconv" "strings" + + "github.com/gin-gonic/gin" + + "github.com/songquanpeng/one-api/common/random" ) func OpenBrowser(url string) { @@ -106,6 +109,18 @@ func GenRequestID() string { return GetTimeString() + random.GetRandomNumberString(8) } +func SetRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, RequestIdKey, id) +} + +func GetRequestID(ctx context.Context) string { + rawRequestId := ctx.Value(RequestIdKey) + if rawRequestId == nil { + return "" + } + return rawRequestId.(string) +} + func GetResponseID(c *gin.Context) string { logID := c.GetString(RequestIdKey) return fmt.Sprintf("chatcmpl-%s", logID) diff --git a/common/logger/logger.go b/common/logger/logger.go index c5797217..1e3bc254 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -113,16 +113,16 @@ func logHelper(ctx context.Context, level loggerLevel, msg string) { if level == loggerINFO { writer = gin.DefaultWriter } - var logId string + var requestId string if ctx != nil { - rawLogId := ctx.Value(helper.RequestIdKey) - if rawLogId != nil { - logId = fmt.Sprintf(" | %s", rawLogId.(string)) + rawRequestId := helper.GetRequestID(ctx) + if rawRequestId != "" { + requestId = fmt.Sprintf(" | %s", rawRequestId) } } lineInfo, funcName := getLineInfo() now := time.Now() - _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), logId, lineInfo, funcName, msg) + _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg) SetupLogger() if level == loggerFatal { os.Exit(1) diff --git a/controller/auth/github.go b/controller/auth/github.go index 15542655..ecdd183c 100644 --- a/controller/auth/github.go +++ b/controller/auth/github.go @@ -5,16 +5,18 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strconv" + "time" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" - "time" ) type GitHubOAuthResponse struct { @@ -81,6 +83,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } func GitHubOAuth(c *gin.Context) { + ctx := c.Request.Context() session := sessions.Default(c) state := c.Query("state") if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { @@ -136,7 +139,7 @@ func GitHubOAuth(c *gin.Context) { user.Role = model.RoleCommonUser user.Status = model.UserStatusEnabled - if err := user.Insert(0); err != nil { + if err := user.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/controller/auth/lark.go b/controller/auth/lark.go index 39088b3c..651d5874 100644 --- a/controller/auth/lark.go +++ b/controller/auth/lark.go @@ -5,15 +5,17 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strconv" + "time" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" - "time" ) type LarkOAuthResponse struct { @@ -79,6 +81,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { } func LarkOAuth(c *gin.Context) { + ctx := c.Request.Context() session := sessions.Default(c) state := c.Query("state") if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { @@ -125,7 +128,7 @@ func LarkOAuth(c *gin.Context) { user.Role = model.RoleCommonUser user.Status = model.UserStatusEnabled - if err := user.Insert(0); err != nil { + if err := user.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go index 7b4ad4b9..1c4eedbe 100644 --- a/controller/auth/oidc.go +++ b/controller/auth/oidc.go @@ -5,15 +5,17 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strconv" + "time" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" - "time" ) type OidcResponse struct { @@ -87,6 +89,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } func OidcAuth(c *gin.Context) { + ctx := c.Request.Context() session := sessions.Default(c) state := c.Query("state") if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { @@ -142,7 +145,7 @@ func OidcAuth(c *gin.Context) { } else { user.DisplayName = "OIDC User" } - err := user.Insert(0) + err := user.Insert(ctx, 0) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/auth/wechat.go b/controller/auth/wechat.go index a561aec0..9c30b8f0 100644 --- a/controller/auth/wechat.go +++ b/controller/auth/wechat.go @@ -4,14 +4,16 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strconv" + "time" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" - "time" ) type wechatLoginResponse struct { @@ -52,6 +54,7 @@ func getWeChatIdByCode(code string) (string, error) { } func WeChatAuth(c *gin.Context) { + ctx := c.Request.Context() if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", @@ -87,7 +90,7 @@ func WeChatAuth(c *gin.Context) { user.Role = model.RoleCommonUser user.Status = model.UserStatusEnabled - if err := user.Insert(0); err != nil { + if err := user.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/controller/user.go b/controller/user.go index e79881c2..43e529e5 100644 --- a/controller/user.go +++ b/controller/user.go @@ -109,6 +109,7 @@ func Logout(c *gin.Context) { } func Register(c *gin.Context) { + ctx := c.Request.Context() if !config.RegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了新用户注册", @@ -166,7 +167,7 @@ func Register(c *gin.Context) { if config.EmailVerificationEnabled { cleanUser.Email = user.Email } - if err := cleanUser.Insert(inviterId); err != nil { + if err := cleanUser.Insert(ctx, inviterId); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -362,6 +363,7 @@ func GetSelf(c *gin.Context) { } func UpdateUser(c *gin.Context) { + ctx := c.Request.Context() var updatedUser model.User err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) if err != nil || updatedUser.Id == 0 { @@ -416,7 +418,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, @@ -535,6 +537,7 @@ func DeleteSelf(c *gin.Context) { } func CreateUser(c *gin.Context) { + ctx := c.Request.Context() var user model.User err := json.NewDecoder(c.Request.Body).Decode(&user) if err != nil || user.Username == "" || user.Password == "" { @@ -568,7 +571,7 @@ func CreateUser(c *gin.Context) { Password: user.Password, DisplayName: user.DisplayName, } - if err := cleanUser.Insert(0); err != nil { + if err := cleanUser.Insert(ctx, 0); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -747,6 +750,7 @@ type topUpRequest struct { } func TopUp(c *gin.Context) { + ctx := c.Request.Context() req := topUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { @@ -757,7 +761,7 @@ func TopUp(c *gin.Context) { return } id := c.GetInt("id") - quota, err := model.Redeem(req.Key, id) + quota, err := model.Redeem(ctx, req.Key, id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -780,6 +784,7 @@ type adminTopUpRequest struct { } func AdminTopUp(c *gin.Context) { + ctx := c.Request.Context() req := adminTopUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { @@ -800,7 +805,7 @@ func AdminTopUp(c *gin.Context) { if req.Remark == "" { req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) } - model.RecordTopupLog(req.UserId, req.Remark, req.Quota) + model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/middleware/request-id.go b/middleware/request-id.go index bef09e32..973a63f8 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -1,8 +1,8 @@ package middleware import ( - "context" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" ) @@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) { return func(c *gin.Context) { id := helper.GenRequestID() c.Set(helper.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) + ctx := helper.SetRequestID(c.Request.Context(), id) c.Request = c.Request.WithContext(ctx) c.Header(helper.RequestIdKey, id) c.Next() diff --git a/model/log.go b/model/log.go index 58fdd513..1fd7ee84 100644 --- a/model/log.go +++ b/model/log.go @@ -4,11 +4,12 @@ import ( "context" "fmt" + "gorm.io/gorm" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "gorm.io/gorm" ) type Log struct { @@ -24,6 +25,7 @@ type Log struct { PromptTokens int `json:"prompt_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"` ChannelId int `json:"channel" gorm:"index"` + RequestId string `json:"request_id"` } const ( @@ -34,7 +36,18 @@ const ( LogTypeSystem ) -func RecordLog(userId int, logType int, content string) { +func recordLogHelper(ctx context.Context, log *Log) { + requestId := helper.GetRequestID(ctx) + log.RequestId = requestId + err := LOG_DB.Create(log).Error + if err != nil { + logger.Error(ctx, "failed to record log: "+err.Error()) + return + } + logger.Infof(ctx, "record log: %+v", log) +} + +func RecordLog(ctx context.Context, userId int, logType int, content string) { if logType == LogTypeConsume && !config.LogConsumeEnabled { return } @@ -45,13 +58,10 @@ func RecordLog(userId int, logType int, content string) { Type: logType, Content: content, } - err := LOG_DB.Create(log).Error - if err != nil { - logger.SysError("failed to record log: " + err.Error()) - } + recordLogHelper(ctx, log) } -func RecordTopupLog(userId int, content string, quota int) { +func RecordTopupLog(ctx context.Context, userId int, content string, quota int) { log := &Log{ UserId: userId, Username: GetUsernameById(userId), @@ -60,14 +70,10 @@ func RecordTopupLog(userId int, content string, quota int) { Content: content, Quota: quota, } - err := LOG_DB.Create(log).Error - if err != nil { - logger.SysError("failed to record log: " + err.Error()) - } + recordLogHelper(ctx, log) } func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { - logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { return } @@ -84,10 +90,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke Quota: int(quota), ChannelId: channelId, } - err := LOG_DB.Create(log).Error - if err != nil { - logger.Error(ctx, "failed to record log: "+err.Error()) - } + recordLogHelper(ctx, log) } func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { diff --git a/model/redemption.go b/model/redemption.go index 45871a71..957a33be 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -1,11 +1,14 @@ package model import ( + "context" "errors" "fmt" + + "gorm.io/gorm" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" - "gorm.io/gorm" ) const ( @@ -48,7 +51,7 @@ func GetRedemptionById(id int) (*Redemption, error) { return &redemption, err } -func Redeem(key string, userId int) (quota int64, err error) { +func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) { if key == "" { return 0, errors.New("未提供兑换码") } @@ -82,7 +85,7 @@ func Redeem(key string, userId int) (quota int64, err error) { if err != nil { return 0, errors.New("兑换失败," + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) + RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) return redemption.Quota, nil } diff --git a/model/user.go b/model/user.go index a964a0d7..a619c901 100644 --- a/model/user.go +++ b/model/user.go @@ -1,16 +1,19 @@ package model import ( + "context" "errors" "fmt" + "strings" + + "gorm.io/gorm" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" - "gorm.io/gorm" - "strings" ) const ( @@ -114,7 +117,7 @@ func DeleteUserById(id int) (err error) { return user.Delete() } -func (user *User) Insert(inviterId int) error { +func (user *User) Insert(ctx context.Context, inviterId int) error { var err error if user.Password != "" { user.Password, err = common.Password2Hash(user.Password) @@ -130,16 +133,16 @@ func (user *User) Insert(inviterId int) error { return result.Error } if config.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) + RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) } if inviterId != 0 { if config.QuotaForInvitee > 0 { _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) + RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) } if config.QuotaForInviter > 0 { _ = IncreaseUserQuota(inviterId, config.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) + RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) } } // create default token