feat: record request_id

This commit is contained in:
JustSong 2025-01-31 17:54:04 +08:00
parent 1fead8e7f7
commit d0402f9086
11 changed files with 96 additions and 55 deletions

View File

@ -1,9 +1,8 @@
package helper package helper
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
"html/template" "html/template"
"log" "log"
"net" "net"
@ -11,6 +10,10 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
) )
func OpenBrowser(url string) { func OpenBrowser(url string) {
@ -106,6 +109,18 @@ func GenRequestID() string {
return GetTimeString() + random.GetRandomNumberString(8) return GetTimeString() + random.GetRandomNumberString(8)
} }
func SetRequestID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, RequestIdKey, id)
}
func GetRequestID(ctx context.Context) string {
rawRequestId := ctx.Value(RequestIdKey)
if rawRequestId == nil {
return ""
}
return rawRequestId.(string)
}
func GetResponseID(c *gin.Context) string { func GetResponseID(c *gin.Context) string {
logID := c.GetString(RequestIdKey) logID := c.GetString(RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID) return fmt.Sprintf("chatcmpl-%s", logID)

View File

@ -113,16 +113,16 @@ func logHelper(ctx context.Context, level loggerLevel, msg string) {
if level == loggerINFO { if level == loggerINFO {
writer = gin.DefaultWriter writer = gin.DefaultWriter
} }
var logId string var requestId string
if ctx != nil { if ctx != nil {
rawLogId := ctx.Value(helper.RequestIdKey) rawRequestId := helper.GetRequestID(ctx)
if rawLogId != nil { if rawRequestId != "" {
logId = fmt.Sprintf(" | %s", rawLogId.(string)) requestId = fmt.Sprintf(" | %s", rawRequestId)
} }
} }
lineInfo, funcName := getLineInfo() lineInfo, funcName := getLineInfo()
now := time.Now() now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), logId, lineInfo, funcName, msg) _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg)
SetupLogger() SetupLogger()
if level == loggerFatal { if level == loggerFatal {
os.Exit(1) os.Exit(1)

View File

@ -5,16 +5,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type GitHubOAuthResponse struct { type GitHubOAuthResponse struct {
@ -81,6 +83,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
} }
func GitHubOAuth(c *gin.Context) { func GitHubOAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@ -136,7 +139,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@ -5,15 +5,17 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type LarkOAuthResponse struct { type LarkOAuthResponse struct {
@ -79,6 +81,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
} }
func LarkOAuth(c *gin.Context) { func LarkOAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@ -125,7 +128,7 @@ func LarkOAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@ -5,15 +5,17 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type OidcResponse struct { type OidcResponse struct {
@ -87,6 +89,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
} }
func OidcAuth(c *gin.Context) { func OidcAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@ -142,7 +145,7 @@ func OidcAuth(c *gin.Context) {
} else { } else {
user.DisplayName = "OIDC User" user.DisplayName = "OIDC User"
} }
err := user.Insert(0) err := user.Insert(ctx, 0)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@ -4,14 +4,16 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type wechatLoginResponse struct { type wechatLoginResponse struct {
@ -52,6 +54,7 @@ func getWeChatIdByCode(code string) (string, error) {
} }
func WeChatAuth(c *gin.Context) { func WeChatAuth(c *gin.Context) {
ctx := c.Request.Context()
if !config.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
@ -87,7 +90,7 @@ func WeChatAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@ -109,6 +109,7 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
ctx := c.Request.Context()
if !config.RegisterEnabled { if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
@ -166,7 +167,7 @@ func Register(c *gin.Context) {
if config.EmailVerificationEnabled { if config.EmailVerificationEnabled {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(inviterId); err != nil { if err := cleanUser.Insert(ctx, inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@ -362,6 +363,7 @@ func GetSelf(c *gin.Context) {
} }
func UpdateUser(c *gin.Context) { func UpdateUser(c *gin.Context) {
ctx := c.Request.Context()
var updatedUser model.User var updatedUser model.User
err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
if err != nil || updatedUser.Id == 0 { if err != nil || updatedUser.Id == 0 {
@ -416,7 +418,7 @@ func UpdateUser(c *gin.Context) {
return return
} }
if originUser.Quota != updatedUser.Quota { if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
@ -535,6 +537,7 @@ func DeleteSelf(c *gin.Context) {
} }
func CreateUser(c *gin.Context) { func CreateUser(c *gin.Context) {
ctx := c.Request.Context()
var user model.User var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user) err := json.NewDecoder(c.Request.Body).Decode(&user)
if err != nil || user.Username == "" || user.Password == "" { if err != nil || user.Username == "" || user.Password == "" {
@ -568,7 +571,7 @@ func CreateUser(c *gin.Context) {
Password: user.Password, Password: user.Password,
DisplayName: user.DisplayName, DisplayName: user.DisplayName,
} }
if err := cleanUser.Insert(0); err != nil { if err := cleanUser.Insert(ctx, 0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@ -747,6 +750,7 @@ type topUpRequest struct {
} }
func TopUp(c *gin.Context) { func TopUp(c *gin.Context) {
ctx := c.Request.Context()
req := topUpRequest{} req := topUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
@ -757,7 +761,7 @@ func TopUp(c *gin.Context) {
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
quota, err := model.Redeem(req.Key, id) quota, err := model.Redeem(ctx, req.Key, id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -780,6 +784,7 @@ type adminTopUpRequest struct {
} }
func AdminTopUp(c *gin.Context) { func AdminTopUp(c *gin.Context) {
ctx := c.Request.Context()
req := adminTopUpRequest{} req := adminTopUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
@ -800,7 +805,7 @@ func AdminTopUp(c *gin.Context) {
if req.Remark == "" { if req.Remark == "" {
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
} }
model.RecordTopupLog(req.UserId, req.Remark, req.Quota) model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",

View File

@ -1,8 +1,8 @@
package middleware package middleware
import ( import (
"context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
) )
@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := helper.GenRequestID() id := helper.GenRequestID()
c.Set(helper.RequestIdKey, id) c.Set(helper.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) ctx := helper.SetRequestID(c.Request.Context(), id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Header(helper.RequestIdKey, id) c.Header(helper.RequestIdKey, id)
c.Next() c.Next()

View File

@ -4,11 +4,12 @@ import (
"context" "context"
"fmt" "fmt"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
) )
type Log struct { type Log struct {
@ -24,6 +25,7 @@ type Log struct {
PromptTokens int `json:"prompt_tokens" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
ChannelId int `json:"channel" gorm:"index"` ChannelId int `json:"channel" gorm:"index"`
RequestId string `json:"request_id"`
} }
const ( const (
@ -34,7 +36,18 @@ const (
LogTypeSystem LogTypeSystem
) )
func RecordLog(userId int, logType int, content string) { func recordLogHelper(ctx context.Context, log *Log) {
requestId := helper.GetRequestID(ctx)
log.RequestId = requestId
err := LOG_DB.Create(log).Error
if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error())
return
}
logger.Infof(ctx, "record log: %+v", log)
}
func RecordLog(ctx context.Context, userId int, logType int, content string) {
if logType == LogTypeConsume && !config.LogConsumeEnabled { if logType == LogTypeConsume && !config.LogConsumeEnabled {
return return
} }
@ -45,13 +58,10 @@ func RecordLog(userId int, logType int, content string) {
Type: logType, Type: logType,
Content: content, Content: content,
} }
err := LOG_DB.Create(log).Error recordLogHelper(ctx, log)
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
} }
func RecordTopupLog(userId int, content string, quota int) { func RecordTopupLog(ctx context.Context, userId int, content string, quota int) {
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
@ -60,14 +70,10 @@ func RecordTopupLog(userId int, content string, quota int) {
Content: content, Content: content,
Quota: quota, Quota: quota,
} }
err := LOG_DB.Create(log).Error recordLogHelper(ctx, log)
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
} }
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled { if !config.LogConsumeEnabled {
return return
} }
@ -84,10 +90,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
Quota: int(quota), Quota: int(quota),
ChannelId: channelId, ChannelId: channelId,
} }
err := LOG_DB.Create(log).Error recordLogHelper(ctx, log)
if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error())
}
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {

View File

@ -1,11 +1,14 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm"
) )
const ( const (
@ -48,7 +51,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
return &redemption, err return &redemption, err
} }
func Redeem(key string, userId int) (quota int64, err error) { func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) {
if key == "" { if key == "" {
return 0, errors.New("未提供兑换码") return 0, errors.New("未提供兑换码")
} }
@ -82,7 +85,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
if err != nil { if err != nil {
return 0, errors.New("兑换失败," + err.Error()) return 0, errors.New("兑换失败," + err.Error())
} }
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
return redemption.Quota, nil return redemption.Quota, nil
} }

View File

@ -1,16 +1,19 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
"strings"
) )
const ( const (
@ -114,7 +117,7 @@ func DeleteUserById(id int) (err error) {
return user.Delete() return user.Delete()
} }
func (user *User) Insert(inviterId int) error { func (user *User) Insert(ctx context.Context, inviterId int) error {
var err error var err error
if user.Password != "" { if user.Password != "" {
user.Password, err = common.Password2Hash(user.Password) user.Password, err = common.Password2Hash(user.Password)
@ -130,16 +133,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error return result.Error
} }
if config.QuotaForNewUser > 0 { if config.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if config.QuotaForInvitee > 0 { if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
} }
if config.QuotaForInviter > 0 { if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
// create default token // create default token