feat: able to fetch every request's cost

This commit is contained in:
Laisky.Cai 2024-04-23 00:58:25 +00:00
parent 7047d9605e
commit 84a6817314
14 changed files with 119 additions and 19 deletions

View File

@ -2,6 +2,7 @@ package ctxkey
const (
Id = "id"
RequestId = "X-Oneapi-Request-Id"
Username = "username"
Role = "role"
Status = "status"

View File

@ -1,7 +1,3 @@
package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/Laisky/one-api/common/config"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/helper"
"github.com/gin-gonic/gin"
)
@ -87,7 +88,7 @@ func logHelper(ctx context.Context, level string, msg string) {
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
id := ctx.Value(ctxkey.RequestId)
if id == nil {
id = helper.GenRequestID()
}

View File

@ -59,7 +59,7 @@ func Relay(c *gin.Context) {
group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey)
requestId := c.GetString(ctxkey.RequestId)
retryTimes := config.RetryTimes
if err := shouldRetry(c, bizErr.StatusCode); err != nil {
logger.Errorf(ctx, "relay error happen, won't retry since of %v", err.Error())

View File

@ -16,6 +16,28 @@ import (
"github.com/jinzhu/copier"
)
func GetRequestCost(c *gin.Context) {
reqId := c.Param("request_id")
if reqId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "request_id 不能为空",
})
return
}
docu, err := model.GetCostByRequestId(reqId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, docu)
}
func GetAllTokens(c *gin.Context) {
userId := c.GetInt(ctxkey.Id)
p, _ := strconv.Atoi(c.Query("p"))

View File

@ -6,6 +6,7 @@ import (
"strconv"
"strings"
gutils "github.com/Laisky/go-utils/v4"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/model"
@ -75,6 +76,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.ChannelRatio, minimalRatio)
c.Set(ctxkey.ChannelModel, channel)
// generate an unique cost id for each request
if _, ok := c.Get(ctxkey.RequestId); !ok {
c.Set(ctxkey.RequestId, gutils.UUID7())
}
c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name)

View File

@ -2,7 +2,8 @@ package middleware
import (
"fmt"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/gin-gonic/gin"
)
@ -10,7 +11,7 @@ func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[logger.RequestIdKey].(string)
requestID = param.Keys[ctxkey.RequestId].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),

View File

@ -2,18 +2,19 @@ package middleware
import (
"context"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/logger"
"github.com/gin-gonic/gin"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := helper.GenRequestID()
c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Set(ctxkey.RequestId, id)
ctx := context.WithValue(c.Request.Context(), ctxkey.RequestId, id)
c.Request = c.Request.WithContext(ctx)
c.Header(logger.RequestIdKey, id)
c.Header(ctxkey.RequestId, id)
c.Next()
}
}

View File

@ -2,17 +2,19 @@ package middleware
import (
"fmt"
"strings"
"github.com/Laisky/one-api/common"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/helper"
"github.com/Laisky/one-api/common/logger"
"github.com/gin-gonic/gin"
"strings"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"message": helper.MessageWithRequestId(message, c.GetString(ctxkey.RequestId)),
"type": "one_api_error",
},
})

47
model/cost.go Normal file
View File

@ -0,0 +1,47 @@
package model
import (
"github.com/Laisky/one-api/common/helper"
"github.com/pkg/errors"
)
type UserRequestCost struct {
Id int `json:"id"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
UserID int `json:"user_id"`
RequestID string `json:"request_id"`
Quota int64 `json:"quota"`
CostUSD float64 `json:"cost_usd" gorm:"-"`
}
// NewUserRequestCost create a new UserRequestCost
func NewUserRequestCost(userID int, quotaID string, quota int64) *UserRequestCost {
return &UserRequestCost{
CreatedTime: helper.GetTimestamp(),
UserID: userID,
RequestID: quotaID,
Quota: quota,
}
}
func (docu *UserRequestCost) Insert() error {
var err error
err = DB.Create(docu).Error
return errors.Wrap(err, "failed to insert UserRequestCost")
}
// GetCostByRequestId get cost by request id
func GetCostByRequestId(reqid string) (*UserRequestCost, error) {
if reqid == "" {
return nil, errors.New("request id is empty")
}
docu := &UserRequestCost{RequestID: reqid}
var err error = nil
if err = DB.First(docu, "request_id = ?", reqid).Error; err != nil {
return nil, errors.Wrap(err, "failed to get cost by request id")
}
docu.CostUSD = float64(docu.Quota) / 500000
return docu, nil
}

View File

@ -121,6 +121,10 @@ func InitDB(envName string) (db *gorm.DB, err error) {
if err != nil {
return nil, err
}
err = db.AutoMigrate(&UserRequestCost{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return nil, err

View File

@ -160,12 +160,11 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil
}
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) (quota int64) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
}
var quota int64
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
@ -193,6 +192,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
return
}
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {

View File

@ -8,7 +8,9 @@ import (
"strings"
"github.com/Laisky/errors/v2"
"github.com/Laisky/one-api/common/ctxkey"
"github.com/Laisky/one-api/common/logger"
"github.com/Laisky/one-api/model"
"github.com/Laisky/one-api/relay"
"github.com/Laisky/one-api/relay/adaptor/openai"
"github.com/Laisky/one-api/relay/apitype"
@ -16,11 +18,11 @@ import (
billingratio "github.com/Laisky/one-api/relay/billing/ratio"
"github.com/Laisky/one-api/relay/channeltype"
"github.com/Laisky/one-api/relay/meta"
"github.com/Laisky/one-api/relay/model"
relaymodel "github.com/Laisky/one-api/relay/model"
"github.com/gin-gonic/gin"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
// get & validate textRequest
@ -108,7 +110,19 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr
}
// post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
go func() {
quota := postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
docu := model.NewUserRequestCost(
c.GetInt(ctxkey.Id),
c.GetString(ctxkey.RequestId),
quota,
)
if err = docu.Insert(); err != nil {
logger.Errorf(c, "insert user request cost failed: %+v", err)
}
}()
return nil
}

View File

@ -94,6 +94,10 @@ func SetApiRouter(router *gin.Engine) {
tokenRoute.PUT("/", controller.UpdateToken)
tokenRoute.DELETE("/:id", controller.DeleteToken)
}
costRoute := apiRouter.Group("/cost")
{
costRoute.GET("/request/:request_id", controller.GetRequestCost)
}
redemptionRoute := apiRouter.Group("/redemption")
redemptionRoute.Use(middleware.AdminAuth())
{