mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-13 09:33:43 +08:00
feat: 支持设置worker访问请求中的图片地址
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if constant.CustomCallbackAddress == "" {
|
||||
return common.ServerAddress
|
||||
return constant.ServerAddress
|
||||
}
|
||||
return constant.CustomCallbackAddress
|
||||
}
|
||||
|
||||
107
service/image.go
Normal file
107
service/image.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
base64String = base64String[idx+1:]
|
||||
}
|
||||
|
||||
// 将base64字符串解码为字节切片
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
fmt.Println("Error: Failed to decode base64 string")
|
||||
return image.Config{}, "", "", err
|
||||
}
|
||||
|
||||
// 创建一个bytes.Buffer用于存储解码后的数据
|
||||
reader := bytes.NewReader(decodedData)
|
||||
config, format, err := getImageConfig(reader)
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
resp, err := DoImageRequest(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
_, err = buffer.ReadFrom(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mimeType = resp.Header.Get("Content-Type")
|
||||
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
response, err := DoImageRequest(imageUrl)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != 200 {
|
||||
err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
var readData []byte
|
||||
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
||||
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||
|
||||
// 从response.Body读取更多的数据直到达到当前的限制
|
||||
additionalData := make([]byte, limit-int64(len(readData)))
|
||||
n, _ := io.ReadFull(response.Body, additionalData)
|
||||
readData = append(readData, additionalData[:n]...)
|
||||
|
||||
// 使用io.MultiReader组合已经读取的数据和response.Body
|
||||
limitReader := io.MultiReader(bytes.NewReader(readData), response.Body)
|
||||
|
||||
var config image.Config
|
||||
var format string
|
||||
config, format, err = getImageConfig(limitReader)
|
||||
if err == nil {
|
||||
return config, format, nil
|
||||
}
|
||||
}
|
||||
|
||||
return image.Config{}, "", err // 返回最后一个错误
|
||||
}
|
||||
|
||||
func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
||||
// 读取图片的头部信息来获取图片尺寸
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
config, err = webp.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
}
|
||||
format = "webp"
|
||||
}
|
||||
if err != nil {
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
return config, format, nil
|
||||
}
|
||||
@@ -79,11 +79,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
var err error
|
||||
var format string
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
|
||||
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
config, format, err = DecodeUrlImageData(imageUrl.Url)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
config, format, _, err = DecodeBase64ImageData(imageUrl.Url)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
26
service/worker.go
Normal file
26
service/worker.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
|
||||
if constant.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
|
||||
workerUrl := constant.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
// post request to worker
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
|
||||
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user