support base64 image

This commit is contained in:
CaIon
2023-11-19 18:59:35 +08:00
parent 6e670c0b2e
commit 57d0fc3021
4 changed files with 90 additions and 32 deletions

View File

@@ -19,12 +19,12 @@ import (
func UpdateMidjourneyTask() {
//revocer
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
for {
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
@@ -55,7 +55,6 @@ func UpdateMidjourneyTask() {
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
@@ -68,8 +67,8 @@ func UpdateMidjourneyTask() {
log.Printf("UpdateMidjourneyTask error: %v", err)
continue
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
@@ -83,12 +82,12 @@ func UpdateMidjourneyTask() {
if err1 == nil && err2 == nil {
jsonData, err3 := json.Marshal(responseWithoutStatus)
if err3 != nil {
log.Fatalf("UpdateMidjourneyTask error1: %v", err3)
log.Printf("UpdateMidjourneyTask error1: %v", err3)
continue
}
err4 := json.Unmarshal(jsonData, &responseStatus)
if err4 != nil {
log.Fatalf("UpdateMidjourneyTask error2: %v", err4)
log.Printf("UpdateMidjourneyTask error2: %v", err4)
continue
}
responseItem.Status = strconv.Itoa(responseStatus.Status)
@@ -138,6 +137,7 @@ func UpdateMidjourneyTask() {
log.Printf("UpdateMidjourneyTask error5: %v", err)
}
log.Printf("UpdateMidjourneyTask success: %v", task)
cancel()
}
}
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/chai2010/webp"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"image"
@@ -75,29 +74,21 @@ func getImageToken(imageUrl MessageImageUrl) (int, error) {
if imageUrl.Detail == "low" {
return 85, nil
}
response, err := http.Get(imageUrl.Url)
var config image.Config
var err error
if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
config, err = common.DecodeUrlImageData(imageUrl.Url)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, err = common.DecodeBase64ImageData(imageUrl.Url)
}
if err != nil {
fmt.Println("Error: Failed to get the URL")
return 0, err
}
// 限制读取的字节数,防止下载整个图片
limitReader := io.LimitReader(response.Body, 8192)
response.Body.Close()
// 读取图片的头部信息来获取图片尺寸
config, _, err := image.DecodeConfig(limitReader)
if err != nil {
common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
config, err = webp.DecodeConfig(limitReader)
if err != nil {
common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
}
}
if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error()))
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
}
if config.Width < 512 && config.Height < 512 {
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {