feat: able to fetch every request's cost

This commit is contained in:
Laisky.Cai 2024-04-23 00:58:25 +00:00
parent 7047d9605e
commit 84a6817314
14 changed files with 119 additions and 19 deletions

View File

@ -2,6 +2,7 @@ package ctxkey
const ( const (
Id = "id" Id = "id"
RequestId = "X-Oneapi-Request-Id"
Username = "username" Username = "username"
Role = "role" Role = "role"
Status = "status" Status = "status"

View File

@ -1,7 +1,3 @@
package logger package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string var LogDir string

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/Laisky/one-api/common/config" "github.com/Laisky/one-api/common/config"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/helper" "github.com/Laisky/one-api/common/helper"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -87,7 +88,7 @@ func logHelper(ctx context.Context, level string, msg string) {
if level == loggerINFO { if level == loggerINFO {
writer = gin.DefaultWriter writer = gin.DefaultWriter
} }
id := ctx.Value(RequestIdKey) id := ctx.Value(ctxkey.RequestId)
if id == nil { if id == nil {
id = helper.GenRequestID() id = helper.GenRequestID()
} }

View File

@ -59,7 +59,7 @@ func Relay(c *gin.Context) {
group := c.GetString(ctxkey.Group) group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel) originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, channelId, channelName, bizErr) go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey) requestId := c.GetString(ctxkey.RequestId)
retryTimes := config.RetryTimes retryTimes := config.RetryTimes
if err := shouldRetry(c, bizErr.StatusCode); err != nil { if err := shouldRetry(c, bizErr.StatusCode); err != nil {
logger.Errorf(ctx, "relay error happen, won't retry since of %v", err.Error()) logger.Errorf(ctx, "relay error happen, won't retry since of %v", err.Error())

View File

@ -16,6 +16,28 @@ import (
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
) )
func GetRequestCost(c *gin.Context) {
reqId := c.Param("request_id")
if reqId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "request_id 不能为空",
})
return
}
docu, err := model.GetCostByRequestId(reqId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, docu)
}
func GetAllTokens(c *gin.Context) { func GetAllTokens(c *gin.Context) {
userId := c.GetInt(ctxkey.Id) userId := c.GetInt(ctxkey.Id)
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))

View File

@ -6,6 +6,7 @@ import (
"strconv" "strconv"
"strings" "strings"
gutils "github.com/Laisky/go-utils/v4"
"github.com/Laisky/one-api/common/ctxkey" "github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/logger" "github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/model" "github.com/Laisky/one-api/model"
@ -75,6 +76,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.ChannelRatio, minimalRatio) c.Set(ctxkey.ChannelRatio, minimalRatio)
c.Set(ctxkey.ChannelModel, channel) c.Set(ctxkey.ChannelModel, channel)
// generate an unique cost id for each request
if _, ok := c.Get(ctxkey.RequestId); !ok {
c.Set(ctxkey.RequestId, gutils.UUID7())
}
c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name) c.Set(ctxkey.ChannelName, channel.Name)

View File

@ -2,7 +2,8 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -10,7 +11,7 @@ func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string var requestID string
if param.Keys != nil { if param.Keys != nil {
requestID = param.Keys[logger.RequestIdKey].(string) requestID = param.Keys[ctxkey.RequestId].(string)
} }
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"), param.TimeStamp.Format("2006/01/02 - 15:04:05"),

View File

@ -2,18 +2,19 @@ package middleware
import ( import (
"context" "context"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/helper" "github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func RequestId() func(c *gin.Context) { func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := helper.GenRequestID() id := helper.GenRequestID()
c.Set(logger.RequestIdKey, id) c.Set(ctxkey.RequestId, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), ctxkey.RequestId, id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Header(logger.RequestIdKey, id) c.Header(ctxkey.RequestId, id)
c.Next() c.Next()
} }
} }

View File

@ -2,17 +2,19 @@ package middleware
import ( import (
"fmt" "fmt"
"strings"
"github.com/Laisky/one-api/common" "github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/helper" "github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/logger" "github.com/Laisky/one-api/common/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"strings"
) )
func abortWithMessage(c *gin.Context, statusCode int, message string) { func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), "message": helper.MessageWithRequestId(message, c.GetString(ctxkey.RequestId)),
"type": "one_api_error", "type": "one_api_error",
}, },
}) })

47
model/cost.go Normal file
View File

@ -0,0 +1,47 @@
package model
import (
"github.com/Laisky/one-api/common/helper"
"github.com/pkg/errors"
)
type UserRequestCost struct {
Id int `json:"id"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
UserID int `json:"user_id"`
RequestID string `json:"request_id"`
Quota int64 `json:"quota"`
CostUSD float64 `json:"cost_usd" gorm:"-"`
}
// NewUserRequestCost create a new UserRequestCost
func NewUserRequestCost(userID int, quotaID string, quota int64) *UserRequestCost {
return &UserRequestCost{
CreatedTime: helper.GetTimestamp(),
UserID: userID,
RequestID: quotaID,
Quota: quota,
}
}
func (docu *UserRequestCost) Insert() error {
var err error
err = DB.Create(docu).Error
return errors.Wrap(err, "failed to insert UserRequestCost")
}
// GetCostByRequestId get cost by request id
func GetCostByRequestId(reqid string) (*UserRequestCost, error) {
if reqid == "" {
return nil, errors.New("request id is empty")
}
docu := &UserRequestCost{RequestID: reqid}
var err error = nil
if err = DB.First(docu, "request_id = ?", reqid).Error; err != nil {
return nil, errors.Wrap(err, "failed to get cost by request id")
}
docu.CostUSD = float64(docu.Quota) / 500000
return docu, nil
}

View File

@ -121,6 +121,10 @@ func InitDB(envName string) (db *gorm.DB, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = db.AutoMigrate(&UserRequestCost{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Token{}) err = db.AutoMigrate(&Token{})
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -160,12 +160,11 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil return preConsumedQuota, nil
} }
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) (quota int64) {
if usage == nil { if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected") logger.Error(ctx, "usage is nil, which is unexpected")
return return
} }
var quota int64
completionRatio := billingratio.GetCompletionRatio(textRequest.Model) completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
@ -193,6 +192,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota)
return
} }
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {

View File

@ -8,7 +8,9 @@ import (
"strings" "strings"
"github.com/Laisky/errors/v2" "github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/logger" "github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/model"
"github.com/Laisky/one-api/relay" "github.com/Laisky/one-api/relay"
"github.com/Laisky/one-api/relay/adaptor/openai" "github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/apitype" "github.com/Laisky/one-api/relay/apitype"
@ -16,11 +18,11 @@ import (
billingratio "github.com/Laisky/one-api/relay/billing/ratio" billingratio "github.com/Laisky/one-api/relay/billing/ratio"
"github.com/Laisky/one-api/relay/channeltype" "github.com/Laisky/one-api/relay/channeltype"
"github.com/Laisky/one-api/relay/meta" "github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model" relaymodel "github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context() ctx := c.Request.Context()
meta := meta.GetByContext(c) meta := meta.GetByContext(c)
// get & validate textRequest // get & validate textRequest
@ -108,7 +110,19 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr return respErr
} }
// post-consume quota // post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) go func() {
quota := postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
docu := model.NewUserRequestCost(
c.GetInt(ctxkey.Id),
c.GetString(ctxkey.RequestId),
quota,
)
if err = docu.Insert(); err != nil {
logger.Errorf(c, "insert user request cost failed: %+v", err)
}
}()
return nil return nil
} }

View File

@ -94,6 +94,10 @@ func SetApiRouter(router *gin.Engine) {
tokenRoute.PUT("/", controller.UpdateToken) tokenRoute.PUT("/", controller.UpdateToken)
tokenRoute.DELETE("/:id", controller.DeleteToken) tokenRoute.DELETE("/:id", controller.DeleteToken)
} }
costRoute := apiRouter.Group("/cost")
{
costRoute.GET("/request/:request_id", controller.GetRequestCost)
}
redemptionRoute := apiRouter.Group("/redemption") redemptionRoute := apiRouter.Group("/redemption")
redemptionRoute.Use(middleware.AdminAuth()) redemptionRoute.Use(middleware.AdminAuth())
{ {