diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 5e22d282..574bed16 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -2,6 +2,7 @@ package ctxkey const ( Id = "id" + RequestId = "X-Oneapi-Request-Id" Username = "username" Role = "role" Status = "status" diff --git a/common/logger/constants.go b/common/logger/constants.go index 78d32062..49df31ec 100644 --- a/common/logger/constants.go +++ b/common/logger/constants.go @@ -1,7 +1,3 @@ package logger -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) - var LogDir string diff --git a/common/logger/logger.go b/common/logger/logger.go index bd30246b..ecb108c3 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Laisky/one-api/common/config" + "github.com/Laisky/one-api/common/ctxkey" "github.com/Laisky/one-api/common/helper" "github.com/gin-gonic/gin" ) @@ -87,7 +88,7 @@ func logHelper(ctx context.Context, level string, msg string) { if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id := ctx.Value(ctxkey.RequestId) if id == nil { id = helper.GenRequestID() } diff --git a/controller/relay.go b/controller/relay.go index 81906691..967a9be8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -59,7 +59,7 @@ func Relay(c *gin.Context) { group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) go processChannelRelayError(ctx, channelId, channelName, bizErr) - requestId := c.GetString(logger.RequestIdKey) + requestId := c.GetString(ctxkey.RequestId) retryTimes := config.RetryTimes if err := shouldRetry(c, bizErr.StatusCode); err != nil { logger.Errorf(ctx, "relay error happen, won't retry since of %v", err.Error()) diff --git a/controller/token.go b/controller/token.go index 80620394..235be6be 100644 --- a/controller/token.go +++ b/controller/token.go @@ -16,6 +16,28 @@ import ( "github.com/jinzhu/copier" ) +func GetRequestCost(c *gin.Context) { + reqId := c.Param("request_id") + if reqId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "request_id 不能为空", + }) + return + } + + docu, err := model.GetCostByRequestId(reqId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, docu) +} + func GetAllTokens(c *gin.Context) { userId := c.GetInt(ctxkey.Id) p, _ := strconv.Atoi(c.Query("p")) diff --git a/middleware/distributor.go b/middleware/distributor.go index b3c99316..85dea881 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + gutils "github.com/Laisky/go-utils/v4" "github.com/Laisky/one-api/common/ctxkey" "github.com/Laisky/one-api/common/logger" "github.com/Laisky/one-api/model" @@ -75,6 +76,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set(ctxkey.ChannelRatio, minimalRatio) c.Set(ctxkey.ChannelModel, channel) + // generate an unique cost id for each request + if _, ok := c.Get(ctxkey.RequestId); !ok { + c.Set(ctxkey.RequestId, gutils.UUID7()) + } + c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) diff --git a/middleware/logger.go b/middleware/logger.go index 25fd1b34..496b3ab1 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -2,7 +2,8 @@ package middleware import ( "fmt" - "github.com/Laisky/one-api/common/logger" + + "github.com/Laisky/one-api/common/ctxkey" "github.com/gin-gonic/gin" ) @@ -10,7 +11,7 @@ func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[logger.RequestIdKey].(string) + requestID = param.Keys[ctxkey.RequestId].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/request-id.go b/middleware/request-id.go index 3042f608..447f5e3f 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -2,18 +2,19 @@ package middleware import ( "context" + + "github.com/Laisky/one-api/common/ctxkey" "github.com/Laisky/one-api/common/helper" - "github.com/Laisky/one-api/common/logger" "github.com/gin-gonic/gin" ) func RequestId() func(c *gin.Context) { return func(c *gin.Context) { id := helper.GenRequestID() - c.Set(logger.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) + c.Set(ctxkey.RequestId, id) + ctx := context.WithValue(c.Request.Context(), ctxkey.RequestId, id) c.Request = c.Request.WithContext(ctx) - c.Header(logger.RequestIdKey, id) + c.Header(ctxkey.RequestId, id) c.Next() } } diff --git a/middleware/utils.go b/middleware/utils.go index 0bd2a10c..2d717345 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -2,17 +2,19 @@ package middleware import ( "fmt" + "strings" + "github.com/Laisky/one-api/common" + "github.com/Laisky/one-api/common/ctxkey" "github.com/Laisky/one-api/common/helper" "github.com/Laisky/one-api/common/logger" "github.com/gin-gonic/gin" - "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), + "message": helper.MessageWithRequestId(message, c.GetString(ctxkey.RequestId)), "type": "one_api_error", }, }) diff --git a/model/cost.go b/model/cost.go new file mode 100644 index 00000000..35118917 --- /dev/null +++ b/model/cost.go @@ -0,0 +1,47 @@ +package model + +import ( + "github.com/Laisky/one-api/common/helper" + "github.com/pkg/errors" +) + +type UserRequestCost struct { + Id int `json:"id"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UserID int `json:"user_id"` + RequestID string `json:"request_id"` + Quota int64 `json:"quota"` + CostUSD float64 `json:"cost_usd" gorm:"-"` +} + +// NewUserRequestCost create a new UserRequestCost +func NewUserRequestCost(userID int, quotaID string, quota int64) *UserRequestCost { + return &UserRequestCost{ + CreatedTime: helper.GetTimestamp(), + UserID: userID, + RequestID: quotaID, + Quota: quota, + } +} + +func (docu *UserRequestCost) Insert() error { + var err error + err = DB.Create(docu).Error + return errors.Wrap(err, "failed to insert UserRequestCost") +} + +// GetCostByRequestId get cost by request id +func GetCostByRequestId(reqid string) (*UserRequestCost, error) { + if reqid == "" { + return nil, errors.New("request id is empty") + } + + docu := &UserRequestCost{RequestID: reqid} + var err error = nil + if err = DB.First(docu, "request_id = ?", reqid).Error; err != nil { + return nil, errors.Wrap(err, "failed to get cost by request id") + } + + docu.CostUSD = float64(docu.Quota) / 500000 + return docu, nil +} diff --git a/model/main.go b/model/main.go index 34730ddd..c046c875 100644 --- a/model/main.go +++ b/model/main.go @@ -121,6 +121,10 @@ func InitDB(envName string) (db *gorm.DB, err error) { if err != nil { return nil, err } + err = db.AutoMigrate(&UserRequestCost{}) + if err != nil { + return nil, err + } err = db.AutoMigrate(&Token{}) if err != nil { return nil, err diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 62054049..66bcb8ac 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -160,12 +160,11 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) (quota int64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } - var quota int64 completionRatio := billingratio.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens @@ -193,6 +192,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota) + + return } func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { diff --git a/relay/controller/text.go b/relay/controller/text.go index 6e700c25..e38b3aa7 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -8,7 +8,9 @@ import ( "strings" "github.com/Laisky/errors/v2" + "github.com/Laisky/one-api/common/ctxkey" "github.com/Laisky/one-api/common/logger" + "github.com/Laisky/one-api/model" "github.com/Laisky/one-api/relay" "github.com/Laisky/one-api/relay/adaptor/openai" "github.com/Laisky/one-api/relay/apitype" @@ -16,11 +18,11 @@ import ( billingratio "github.com/Laisky/one-api/relay/billing/ratio" "github.com/Laisky/one-api/relay/channeltype" "github.com/Laisky/one-api/relay/meta" - "github.com/Laisky/one-api/relay/model" + relaymodel "github.com/Laisky/one-api/relay/model" "github.com/gin-gonic/gin" ) -func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { +func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) // get & validate textRequest @@ -108,7 +110,19 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return respErr } + // post-consume quota - go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) + go func() { + quota := postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) + docu := model.NewUserRequestCost( + c.GetInt(ctxkey.Id), + c.GetString(ctxkey.RequestId), + quota, + ) + if err = docu.Insert(); err != nil { + logger.Errorf(c, "insert user request cost failed: %+v", err) + } + }() + return nil } diff --git a/router/api.go b/router/api.go index 7d84e301..4575c805 100644 --- a/router/api.go +++ b/router/api.go @@ -94,6 +94,10 @@ func SetApiRouter(router *gin.Engine) { tokenRoute.PUT("/", controller.UpdateToken) tokenRoute.DELETE("/:id", controller.DeleteToken) } + costRoute := apiRouter.Group("/cost") + { + costRoute.GET("/request/:request_id", controller.GetRequestCost) + } redemptionRoute := apiRouter.Group("/redemption") redemptionRoute.Use(middleware.AdminAuth()) {