🔖 chore: migration logger package

This commit is contained in:
MartialBE
2024-05-29 01:04:23 +08:00
parent 79524108a3
commit ce12558ad6
44 changed files with 207 additions and 174 deletions

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"one-api/common"
"one-api/common/logger"
"one-api/common/notify"
"one-api/common/utils"
"one-api/model"
@@ -70,7 +71,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
// 转换为JSON字符串
jsonBytes, _ := json.Marshal(response)
common.SysLog(fmt.Sprintf("测试渠道 %s : %s 返回内容为:%s", channel.Name, request.Model, string(jsonBytes)))
logger.SysLog(fmt.Sprintf("测试渠道 %s : %s 返回内容为:%s", channel.Name, request.Model, string(jsonBytes)))
return nil, nil
}
@@ -233,8 +234,8 @@ func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels")
logger.SysLog("testing all channels")
_ = testAllChannels(false)
common.SysLog("channel test finished")
logger.SysLog("channel test finished")
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/logger"
"one-api/common/utils"
"one-api/model"
"strconv"
@@ -48,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -64,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/logger"
"one-api/model"
"strconv"
"time"
@@ -58,7 +59,7 @@ func getLarkAppAccessToken() (string, error) {
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return "", errors.New("无法连接至飞书服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -100,7 +101,7 @@ func getLarkUserAccessToken(code string) (string, error) {
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return "", errors.New("无法连接至飞书服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -135,7 +136,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
}
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
}
var larkUser LarkUser

View File

@@ -11,6 +11,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/common/logger"
"one-api/common/requester"
"one-api/model"
provider "one-api/providers/midjourney"
@@ -45,9 +46,9 @@ func ActivateUpdateMidjourneyTaskBulk() {
}
func UpdateMidjourneyTaskBulk() {
ctx := context.WithValue(context.Background(), common.RequestIdKey, "MidjourneyTask")
ctx := context.WithValue(context.Background(), logger.RequestIdKey, "MidjourneyTask")
for {
common.LogInfo(ctx, "running")
logger.LogInfo(ctx, "running")
tasks := model.GetAllUnFinishTasks()
@@ -56,11 +57,11 @@ func UpdateMidjourneyTaskBulk() {
for len(activeMidjourneyTask) > 0 {
<-activeMidjourneyTask
}
common.LogInfo(ctx, "no tasks, waiting...")
logger.LogInfo(ctx, "no tasks, waiting...")
return
}
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
@@ -79,9 +80,9 @@ func UpdateMidjourneyTaskBulk() {
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
@@ -89,7 +90,7 @@ func UpdateMidjourneyTaskBulk() {
}
for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
@@ -100,7 +101,7 @@ func UpdateMidjourneyTaskBulk() {
"status": "FAILURE",
"progress": "100%",
})
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
@@ -110,7 +111,7 @@ func UpdateMidjourneyTaskBulk() {
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
@@ -122,22 +123,22 @@ func UpdateMidjourneyTaskBulk() {
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := requester.HTTPClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []provider.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
@@ -176,17 +177,17 @@ func UpdateMidjourneyTaskBulk() {
}
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
logger.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
@@ -195,7 +196,7 @@ func UpdateMidjourneyTaskBulk() {
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
}
}
}