mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 05:13:41 +08:00 
			
		
		
		
	refactor: update logging related logic
This commit is contained in:
		
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -4,4 +4,5 @@ upload
 | 
			
		||||
*.exe
 | 
			
		||||
*.db
 | 
			
		||||
build
 | 
			
		||||
*.db-journal
 | 
			
		||||
*.db-journal
 | 
			
		||||
logs
 | 
			
		||||
@@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
 | 
			
		||||
var BatchUpdateEnabled = false
 | 
			
		||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RequestIdKey = "X-Oneapi-Request-Id"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ var (
 | 
			
		||||
	Port         = flag.Int("port", 3000, "the listening port")
 | 
			
		||||
	PrintVersion = flag.Bool("version", false, "print version and exit")
 | 
			
		||||
	PrintHelp    = flag.Bool("help", false, "print help and exit")
 | 
			
		||||
	LogDir       = flag.String("log-dir", "", "specify the log directory")
 | 
			
		||||
	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func printHelp() {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -10,20 +11,21 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	loggerINFO  = "INFO"
 | 
			
		||||
	loggerWarn  = "WARN"
 | 
			
		||||
	loggerError = "ERR"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetupGinLog() {
 | 
			
		||||
	if *LogDir != "" {
 | 
			
		||||
		commonLogPath := filepath.Join(*LogDir, "common.log")
 | 
			
		||||
		errorLogPath := filepath.Join(*LogDir, "error.log")
 | 
			
		||||
		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
 | 
			
		||||
		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatal("failed to open log file")
 | 
			
		||||
		}
 | 
			
		||||
		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatal("failed to open log file")
 | 
			
		||||
		}
 | 
			
		||||
		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
 | 
			
		||||
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
 | 
			
		||||
		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
 | 
			
		||||
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -37,6 +39,28 @@ func SysError(s string) {
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogInfo(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerINFO, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogWarn(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerWarn, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogError(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerError, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func logHelper(ctx context.Context, level string, msg string) {
 | 
			
		||||
	writer := gin.DefaultErrorWriter
 | 
			
		||||
	if level == loggerINFO {
 | 
			
		||||
		writer = gin.DefaultWriter
 | 
			
		||||
	}
 | 
			
		||||
	id := ctx.Value(RequestIdKey)
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FatalLog(v ...any) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 | 
			
		||||
 
 | 
			
		||||
@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
 | 
			
		||||
	return time.Now().Unix()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTimeString() string {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Max(a int, b int) int {
 | 
			
		||||
	if a >= b {
 | 
			
		||||
		return a
 | 
			
		||||
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MessageWithRequestId(message string, id string) string {
 | 
			
		||||
	return fmt.Sprintf("%s (request id: %s)", message, id)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	}
 | 
			
		||||
	var audioResponse AudioResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		go func() {
 | 
			
		||||
			quota := countTokenText(audioResponse.Text, audioModel)
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
@@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
			
		||||
				model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse ImageResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quota)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
			
		||||
				model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -210,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		// in this case, we do not pre-consume quota
 | 
			
		||||
		// because the user has enough quota
 | 
			
		||||
		preConsumedQuota = 0
 | 
			
		||||
		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 | 
			
		||||
	}
 | 
			
		||||
	if consumeQuota && preConsumedQuota > 0 {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
@@ -348,13 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
			if preConsumedQuota != 0 {
 | 
			
		||||
				go func() {
 | 
			
		||||
				go func(ctx context.Context) {
 | 
			
		||||
					// return pre-consumed quota
 | 
			
		||||
					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error return pre-consumed quota: " + err.Error())
 | 
			
		||||
						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
				}(c.Request.Context())
 | 
			
		||||
			}
 | 
			
		||||
			return relayErrorHandler(resp)
 | 
			
		||||
		}
 | 
			
		||||
@@ -364,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	tokenName := c.GetString("token_name")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		// c.Writer.Flush()
 | 
			
		||||
		go func() {
 | 
			
		||||
			if consumeQuota {
 | 
			
		||||
@@ -387,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
				quotaDelta := quota - preConsumedQuota
 | 
			
		||||
				err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
					common.LogError(ctx, "error update user quota cache: "+err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				if quota != 0 {
 | 
			
		||||
					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
					model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
					model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if isStream {
 | 
			
		||||
 
 | 
			
		||||
@@ -196,6 +196,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		err = relayTextHelper(c, relayMode)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		requestId := c.GetString(common.RequestIdKey)
 | 
			
		||||
		retryTimesStr := c.Query("retry")
 | 
			
		||||
		retryTimes, _ := strconv.Atoi(retryTimesStr)
 | 
			
		||||
		if retryTimesStr == "" {
 | 
			
		||||
@@ -207,12 +208,13 @@ func Relay(c *gin.Context) {
 | 
			
		||||
			if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
			
		||||
			}
 | 
			
		||||
			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 | 
			
		||||
			c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
				"error": err.OpenAIError,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										7
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								main.go
									
									
									
									
									
								
							@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/controller"
 | 
			
		||||
	"one-api/middleware"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"one-api/router"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -84,10 +85,12 @@ func main() {
 | 
			
		||||
	controller.InitTokenEncoders()
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.Default()
 | 
			
		||||
	server := gin.New()
 | 
			
		||||
	server.Use(gin.Recovery())
 | 
			
		||||
	// This will cause SSE not to work!!!
 | 
			
		||||
	//server.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
 | 
			
		||||
	server.Use(middleware.RequestId())
 | 
			
		||||
	middleware.SetUpLogger(server)
 | 
			
		||||
	// Initialize session store
 | 
			
		||||
	store := cookie.NewStore([]byte(common.SessionSecret))
 | 
			
		||||
	server.Use(sessions.Sessions("session", store))
 | 
			
		||||
 
 | 
			
		||||
@@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
		key = parts[0]
 | 
			
		||||
		token, err := model.ValidateUserToken(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusUnauthorized, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusUnauthorized, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		userEnabled, err := model.IsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !userEnabled {
 | 
			
		||||
			c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": "用户已被封禁",
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("id", token.UserId)
 | 
			
		||||
@@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			if model.IsAdmin(token.UserId) {
 | 
			
		||||
				c.Set("channelId", parts[1])
 | 
			
		||||
			} else {
 | 
			
		||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "普通用户不支持指定渠道",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的渠道 ID",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.GetChannelById(id, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的渠道 ID",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "该渠道已被禁用",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的请求",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
@@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
			
		||||
					message = "数据库一致性已被破坏,请联系管理员"
 | 
			
		||||
				}
 | 
			
		||||
				c.JSON(http.StatusServiceUnavailable, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": message,
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusServiceUnavailable, message)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetUpLogger(server *gin.Engine) {
 | 
			
		||||
	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 | 
			
		||||
		var requestID string
 | 
			
		||||
		if param.Keys != nil {
 | 
			
		||||
			requestID = param.Keys[common.RequestIdKey].(string)
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
 | 
			
		||||
			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
 | 
			
		||||
			requestID,
 | 
			
		||||
			param.StatusCode,
 | 
			
		||||
			param.Latency,
 | 
			
		||||
			param.ClientIP,
 | 
			
		||||
			param.Method,
 | 
			
		||||
			param.Path,
 | 
			
		||||
		)
 | 
			
		||||
	}))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func RequestId() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		id := common.GetTimeString() + common.GetRandomString(8)
 | 
			
		||||
		c.Set(common.RequestIdKey, id)
 | 
			
		||||
		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
 | 
			
		||||
		c.Request = c.Request.WithContext(ctx)
 | 
			
		||||
		c.Header(common.RequestIdKey, id)
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | 
			
		||||
	c.JSON(statusCode, gin.H{
 | 
			
		||||
		"error": gin.H{
 | 
			
		||||
			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
 | 
			
		||||
			"type":    "one_api_error",
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	c.Abort()
 | 
			
		||||
	common.LogError(c.Request.Context(), message)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,6 +1,8 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
@@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
			
		||||
	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 | 
			
		||||
	if !common.LogConsumeEnabled {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
 | 
			
		||||
	}
 | 
			
		||||
	err := DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to record log: " + err.Error())
 | 
			
		||||
		common.LogError(ctx, "failed to record log: "+err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user