mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
support base64 image
This commit is contained in:
parent
6e670c0b2e
commit
57d0fc3021
64
common/image.go
Normal file
64
common/image.go
Normal file
@ -0,0 +1,64 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/chai2010/webp"
|
||||
"image"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
base64String = base64String[idx+1:]
|
||||
}
|
||||
|
||||
// 将base64字符串解码为字节切片
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
fmt.Println("Error: Failed to decode base64 string")
|
||||
return image.Config{}, err
|
||||
}
|
||||
|
||||
// 创建一个bytes.Buffer用于存储解码后的数据
|
||||
reader := bytes.NewReader(decodedData)
|
||||
config, err := getImageConfig(reader)
|
||||
return config, err
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, error) {
|
||||
response, err := http.Get(imageUrl)
|
||||
if err != nil {
|
||||
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, err
|
||||
}
|
||||
|
||||
// 限制读取的字节数,防止下载整个图片
|
||||
limitReader := io.LimitReader(response.Body, 8192)
|
||||
config, err := getImageConfig(limitReader)
|
||||
response.Body.Close()
|
||||
return config, err
|
||||
}
|
||||
|
||||
func getImageConfig(reader io.Reader) (image.Config, error) {
|
||||
// 读取图片的头部信息来获取图片尺寸
|
||||
config, _, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
SysLog(err.Error())
|
||||
config, err = webp.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
SysLog(err.Error())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return image.Config{}, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
@ -19,12 +19,12 @@ import (
|
||||
func UpdateMidjourneyTask() {
|
||||
//revocer
|
||||
imageModel := "midjourney"
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) != 0 {
|
||||
@ -55,7 +55,6 @@ func UpdateMidjourneyTask() {
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 5
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
@ -68,8 +67,8 @@ func UpdateMidjourneyTask() {
|
||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
log.Printf("responseBody: %s", string(responseBody))
|
||||
var responseItem Midjourney
|
||||
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
|
||||
@ -83,12 +82,12 @@ func UpdateMidjourneyTask() {
|
||||
if err1 == nil && err2 == nil {
|
||||
jsonData, err3 := json.Marshal(responseWithoutStatus)
|
||||
if err3 != nil {
|
||||
log.Fatalf("UpdateMidjourneyTask error1: %v", err3)
|
||||
log.Printf("UpdateMidjourneyTask error1: %v", err3)
|
||||
continue
|
||||
}
|
||||
err4 := json.Unmarshal(jsonData, &responseStatus)
|
||||
if err4 != nil {
|
||||
log.Fatalf("UpdateMidjourneyTask error2: %v", err4)
|
||||
log.Printf("UpdateMidjourneyTask error2: %v", err4)
|
||||
continue
|
||||
}
|
||||
responseItem.Status = strconv.Itoa(responseStatus.Status)
|
||||
@ -138,6 +137,7 @@ func UpdateMidjourneyTask() {
|
||||
log.Printf("UpdateMidjourneyTask error5: %v", err)
|
||||
}
|
||||
log.Printf("UpdateMidjourneyTask success: %v", task)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/chai2010/webp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"image"
|
||||
@ -75,29 +74,21 @@ func getImageToken(imageUrl MessageImageUrl) (int, error) {
|
||||
if imageUrl.Detail == "low" {
|
||||
return 85, nil
|
||||
}
|
||||
|
||||
response, err := http.Get(imageUrl.Url)
|
||||
var config image.Config
|
||||
var err error
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
|
||||
config, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Println("Error: Failed to get the URL")
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 限制读取的字节数,防止下载整个图片
|
||||
limitReader := io.LimitReader(response.Body, 8192)
|
||||
|
||||
response.Body.Close()
|
||||
|
||||
// 读取图片的头部信息来获取图片尺寸
|
||||
config, _, err := image.DecodeConfig(limitReader)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
config, err = webp.DecodeConfig(limitReader)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
}
|
||||
}
|
||||
if config.Width == 0 || config.Height == 0 {
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error()))
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
|
||||
}
|
||||
if config.Width < 512 && config.Height < 512 {
|
||||
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
||||
|
@ -106,7 +106,7 @@ const LogsTable = () => {
|
||||
return (
|
||||
record.type === 0 || record.type === 2 ?
|
||||
<div>
|
||||
<Tag color='grey' size='large' onClick={()=>{
|
||||
<Tag color='grey' size='large' onClick={() => {
|
||||
copyText(text)
|
||||
}}> {text} </Tag>
|
||||
</div>
|
||||
@ -133,7 +133,7 @@ const LogsTable = () => {
|
||||
return (
|
||||
record.type === 0 || record.type === 2 ?
|
||||
<div>
|
||||
<Tag color={stringToColor(text)} size='large' onClick={()=>{
|
||||
<Tag color={stringToColor(text)} size='large' onClick={() => {
|
||||
copyText(text)
|
||||
}}> {text} </Tag>
|
||||
</div>
|
||||
@ -202,11 +202,12 @@ const LogsTable = () => {
|
||||
const [logType, setLogType] = useState(0);
|
||||
const isAdminUser = isAdmin();
|
||||
let now = new Date();
|
||||
// 初始化start_timestamp为前一天
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
token_name: '',
|
||||
model_name: '',
|
||||
start_timestamp: timestamp2string(0),
|
||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: ''
|
||||
});
|
||||
@ -338,7 +339,7 @@ const LogsTable = () => {
|
||||
showSuccess('已复制:' + text);
|
||||
} else {
|
||||
// setSearchKeyword(text);
|
||||
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
|
||||
Modal.error({title: '无法复制到剪贴板,请手动复制', content: text});
|
||||
}
|
||||
}
|
||||
|
||||
@ -412,10 +413,12 @@ const LogsTable = () => {
|
||||
name='model_name'
|
||||
onChange={value => handleInputChange(value, 'model_name')}/>
|
||||
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp} type='dateTime'
|
||||
name='start_timestamp'
|
||||
onChange={value => handleInputChange(value, 'start_timestamp')}/>
|
||||
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp} type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={value => handleInputChange(value, 'end_timestamp')}/>
|
||||
|
Loading…
Reference in New Issue
Block a user