new-api/controller/midjourney.go
2024-01-14 16:17:45 +08:00

384 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"time"
)
/*func UpdateMidjourneyTask() {
//revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks {
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId)
task.Status = "FAILURE"
task.Progress = "100%"
err := task.Update()
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
continue
}
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
err2 := json.Unmarshal(responseBody, &responseStatus)
if err1 == nil && err2 == nil {
jsonData, err3 := json.Marshal(responseWithoutStatus)
if err3 != nil {
log.Printf("UpdateMidjourneyTask error1: %v", err3)
continue
}
err4 := json.Unmarshal(jsonData, &responseStatus)
if err4 != nil {
log.Printf("UpdateMidjourneyTask error2: %v", err4)
continue
}
responseItem.Status = strconv.Itoa(responseStatus.Status)
} else {
log.Printf("UpdateMidjourneyTask error3: %v", err)
continue
}
} else {
log.Printf("UpdateMidjourneyTask error4: %v", err)
continue
}
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
log.Println("error update user quota cache: " + err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err := model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error5: %v", err)
}
log.Printf("UpdateMidjourneyTask success: %v", task)
cancel()
}
}
}
}
*/
func UpdateMidjourneyTaskBulk() {
//imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) == 0 {
continue
}
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
for _, task := range tasks {
if task.MjId == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.Id)
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(nullTaskIds) > 0 {
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
req.Body.Close()
cancel()
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
}
}
}
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
if oldTask.Code != 1 {
return true
}
if oldTask.Progress != newTask.Progress {
return true
}
if oldTask.PromptEn != newTask.PromptEn {
return true
}
if oldTask.State != newTask.State {
return true
}
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.ImageUrl != newTask.ImageUrl {
return true
}
if oldTask.Status != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true
}
return false
}
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
// 解析其他查询参数
queryParams := model.TaskQueryParams{
ChannelID: c.Query("channel_id"),
MjID: c.Query("mj_id"),
StartTimestamp: c.Query("start_timestamp"),
EndTimestamp: c.Query("end_timestamp"),
}
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
userId := c.GetInt("id")
log.Printf("userId = %d \n", userId)
queryParams := model.TaskQueryParams{
MjID: c.Query("mj_id"),
StartTimestamp: c.Query("start_timestamp"),
EndTimestamp: c.Query("end_timestamp"),
}
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if !strings.Contains(common.ServerAddress, "localhost") {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}