diff --git a/common/image.go b/common/image.go new file mode 100644 index 0000000..d80cbb7 --- /dev/null +++ b/common/image.go @@ -0,0 +1,64 @@ +package common + +import ( + "bytes" + "encoding/base64" + "errors" + "fmt" + "github.com/chai2010/webp" + "image" + "io" + "net/http" + "strings" +) + +func DecodeBase64ImageData(base64String string) (image.Config, error) { + // 去除base64数据的URL前缀(如果有) + if idx := strings.Index(base64String, ","); idx != -1 { + base64String = base64String[idx+1:] + } + + // 将base64字符串解码为字节切片 + decodedData, err := base64.StdEncoding.DecodeString(base64String) + if err != nil { + fmt.Println("Error: Failed to decode base64 string") + return image.Config{}, err + } + + // 创建一个bytes.Buffer用于存储解码后的数据 + reader := bytes.NewReader(decodedData) + config, err := getImageConfig(reader) + return config, err +} + +func DecodeUrlImageData(imageUrl string) (image.Config, error) { + response, err := http.Get(imageUrl) + if err != nil { + SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) + return image.Config{}, err + } + + // 限制读取的字节数,防止下载整个图片 + limitReader := io.LimitReader(response.Body, 8192) + config, err := getImageConfig(limitReader) + response.Body.Close() + return config, err +} + +func getImageConfig(reader io.Reader) (image.Config, error) { + // 读取图片的头部信息来获取图片尺寸 + config, _, err := image.DecodeConfig(reader) + if err != nil { + err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) + SysLog(err.Error()) + config, err = webp.DecodeConfig(reader) + if err != nil { + err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) + SysLog(err.Error()) + } + } + if err != nil { + return image.Config{}, err + } + return config, nil +} diff --git a/controller/midjourney.go b/controller/midjourney.go index d9f4f89..a5ef8e4 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -19,12 +19,12 @@ import ( func UpdateMidjourneyTask() { //revocer imageModel := "midjourney" + defer func() { + if err := recover(); err != nil { + log.Printf("UpdateMidjourneyTask panic: %v", err) + } + }() for { - defer func() { - if err := recover(); err != nil { - log.Printf("UpdateMidjourneyTask panic: %v", err) - } - }() time.Sleep(time.Duration(15) * time.Second) tasks := model.GetAllUnFinishTasks() if len(tasks) != 0 { @@ -55,7 +55,6 @@ func UpdateMidjourneyTask() { // 设置超时时间 timeout := time.Second * 5 ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() // 使用带有超时的 context 创建新的请求 req = req.WithContext(ctx) @@ -68,8 +67,8 @@ func UpdateMidjourneyTask() { log.Printf("UpdateMidjourneyTask error: %v", err) continue } - defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) + resp.Body.Close() log.Printf("responseBody: %s", string(responseBody)) var responseItem Midjourney // err = json.NewDecoder(resp.Body).Decode(&responseItem) @@ -83,12 +82,12 @@ func UpdateMidjourneyTask() { if err1 == nil && err2 == nil { jsonData, err3 := json.Marshal(responseWithoutStatus) if err3 != nil { - log.Fatalf("UpdateMidjourneyTask error1: %v", err3) + log.Printf("UpdateMidjourneyTask error1: %v", err3) continue } err4 := json.Unmarshal(jsonData, &responseStatus) if err4 != nil { - log.Fatalf("UpdateMidjourneyTask error2: %v", err4) + log.Printf("UpdateMidjourneyTask error2: %v", err4) continue } responseItem.Status = strconv.Itoa(responseStatus.Status) @@ -138,6 +137,7 @@ func UpdateMidjourneyTask() { log.Printf("UpdateMidjourneyTask error5: %v", err) } log.Printf("UpdateMidjourneyTask success: %v", task) + cancel() } } } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 6b1322b..ce9a3bf 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/chai2010/webp" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" "image" @@ -75,29 +74,21 @@ func getImageToken(imageUrl MessageImageUrl) (int, error) { if imageUrl.Detail == "low" { return 85, nil } - - response, err := http.Get(imageUrl.Url) + var config image.Config + var err error + if strings.HasPrefix(imageUrl.Url, "http") { + common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url)) + config, err = common.DecodeUrlImageData(imageUrl.Url) + } else { + common.SysLog(fmt.Sprintf("decoding image")) + config, err = common.DecodeBase64ImageData(imageUrl.Url) + } if err != nil { - fmt.Println("Error: Failed to get the URL") return 0, err } - // 限制读取的字节数,防止下载整个图片 - limitReader := io.LimitReader(response.Body, 8192) - - response.Body.Close() - - // 读取图片的头部信息来获取图片尺寸 - config, _, err := image.DecodeConfig(limitReader) - if err != nil { - common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) - config, err = webp.DecodeConfig(limitReader) - if err != nil { - common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) - } - } if config.Width == 0 || config.Height == 0 { - return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error())) + return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) } if config.Width < 512 && config.Height < 512 { if imageUrl.Detail == "auto" || imageUrl.Detail == "" { diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index e2aa8aa..8556493 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -106,7 +106,7 @@ const LogsTable = () => { return ( record.type === 0 || record.type === 2 ?
- { + { copyText(text) }}> {text}
@@ -133,7 +133,7 @@ const LogsTable = () => { return ( record.type === 0 || record.type === 2 ?
- { + { copyText(text) }}> {text}
@@ -202,11 +202,12 @@ const LogsTable = () => { const [logType, setLogType] = useState(0); const isAdminUser = isAdmin(); let now = new Date(); + // 初始化start_timestamp为前一天 const [inputs, setInputs] = useState({ username: '', token_name: '', model_name: '', - start_timestamp: timestamp2string(0), + start_timestamp: timestamp2string(now.getTime() / 1000 - 86400), end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), channel: '' }); @@ -338,7 +339,7 @@ const LogsTable = () => { showSuccess('已复制:' + text); } else { // setSearchKeyword(text); - Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + Modal.error({title: '无法复制到剪贴板,请手动复制', content: text}); } } @@ -412,10 +413,12 @@ const LogsTable = () => { name='model_name' onChange={value => handleInputChange(value, 'model_name')}/> handleInputChange(value, 'start_timestamp')}/> handleInputChange(value, 'end_timestamp')}/>