diff --git a/common/constants.go b/common/constants.go
index 96c0b57..f5dbb3d 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -12,7 +12,6 @@ import (
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API"
-var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
diff --git a/constant/system.go b/constant/system.go
new file mode 100644
index 0000000..b2976e4
--- /dev/null
+++ b/constant/system.go
@@ -0,0 +1,9 @@
+package constant
+
+var ServerAddress = "http://localhost:3000"
+var WorkerUrl = ""
+var WorkerValidKey = ""
+
+func EnableWorker() bool {
+ return WorkerUrl != ""
+}
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 2d538f1..508c5dd 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -235,7 +235,7 @@ func GetAllMidjourney(c *gin.Context) {
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
- midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
+ midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
@@ -267,7 +267,7 @@ func GetUserMidjourney(c *gin.Context) {
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
- midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
+ midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
diff --git a/controller/misc.go b/controller/misc.go
index 8c59952..b8203f3 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -45,7 +45,7 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
- "server_address": common.ServerAddress,
+ "server_address": constant.ServerAddress,
"price": constant.Price,
"min_topup": constant.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
@@ -203,7 +203,7 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
- link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
+ link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("
您好,你正在进行%s密码重置。
"+
"点击 此处 进行密码重置。
"+
diff --git a/controller/topup.go b/controller/topup.go
index ebb24a9..90c3f77 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -92,7 +92,7 @@ func RequestEpay(c *gin.Context) {
payType = epay.WechatPay
}
callBackAddress := service.GetCallbackAddress()
- returnUrl, _ := url.Parse(common.ServerAddress + "/log")
+ returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
client := GetEpayClient()
diff --git a/model/option.go b/model/option.go
index bfa7ddc..6aa59cb 100644
--- a/model/option.go
+++ b/model/option.go
@@ -59,6 +59,8 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
+ common.OptionMap["WorkerUrl"] = constant.WorkerUrl
+ common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
@@ -232,7 +234,11 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken":
common.SMTPToken = value
case "ServerAddress":
- common.ServerAddress = value
+ constant.ServerAddress = value
+ case "WorkerUrl":
+ constant.WorkerUrl = value
+ case "WorkerValidKey":
+ constant.WorkerValidKey = value
case "PayAddress":
constant.PayAddress = value
case "CustomCallbackAddress":
diff --git a/model/token.go b/model/token.go
index 056156e..d921b1d 100644
--- a/model/token.go
+++ b/model/token.go
@@ -5,6 +5,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/constant"
"strconv"
"strings"
)
@@ -297,7 +298,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
prompt = "您的额度已用尽"
}
if email != "" {
- topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
+ topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 7195bcf..f9c6ab3 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -138,11 +138,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url,获取图片的类型和base64编码的数据
- mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
+ mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data
} else {
- _, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
+ _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
if err != nil {
return nil, err
}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 79644af..6d45b57 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -74,7 +74,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
if imageNum > GeminiVisionMaxImageNum {
continue
}
- mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
+ mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index c83493b..f06956a 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -111,7 +111,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
- midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
+ midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
diff --git a/service/epay.go b/service/epay.go
index 4678157..46ea401 100644
--- a/service/epay.go
+++ b/service/epay.go
@@ -1,13 +1,12 @@
package service
import (
- "one-api/common"
"one-api/constant"
)
func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" {
- return common.ServerAddress
+ return constant.ServerAddress
}
return constant.CustomCallbackAddress
}
diff --git a/common/image.go b/service/image.go
similarity index 84%
rename from common/image.go
rename to service/image.go
index 41ff51f..f3eddff 100644
--- a/common/image.go
+++ b/service/image.go
@@ -1,4 +1,4 @@
-package common
+package service
import (
"bytes"
@@ -8,7 +8,7 @@ import (
"golang.org/x/image/webp"
"image"
"io"
- "net/http"
+ "one-api/common"
"strings"
)
@@ -31,25 +31,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
return config, format, base64String, err
}
-func IsImageUrl(url string) (bool, error) {
- resp, err := http.Head(url)
- if err != nil {
- return false, err
- }
- if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
- return false, nil
- }
- return true, nil
-}
-
// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
- isImage, err := IsImageUrl(url)
- if !isImage {
+ resp, err := DoImageRequest(url)
+ if err != nil {
return
}
- resp, err := http.Get(url)
- if err != nil {
+ if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return
}
defer resp.Body.Close()
@@ -64,16 +52,21 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
- response, err := http.Get(imageUrl)
+ response, err := DoImageRequest(imageUrl)
if err != nil {
- SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
+ common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err
}
defer response.Body.Close()
+ if response.StatusCode != 200 {
+ err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
+ return image.Config{}, "", err
+ }
+
var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
- SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
+ common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
// 从response.Body读取更多的数据直到达到当前的限制
additionalData := make([]byte, limit-int64(len(readData)))
@@ -99,11 +92,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
config, format, err := image.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
- SysLog(err.Error())
+ common.SysLog(err.Error())
config, err = webp.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
- SysLog(err.Error())
+ common.SysLog(err.Error())
}
format = "webp"
}
diff --git a/service/token_counter.go b/service/token_counter.go
index c540ac5..cdca1fd 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -86,11 +86,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
var err error
var format string
if strings.HasPrefix(imageUrl.Url, "http") {
- common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
- config, format, err = common.DecodeUrlImageData(imageUrl.Url)
+ config, format, err = DecodeUrlImageData(imageUrl.Url)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
- config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
+ config, format, _, err = DecodeBase64ImageData(imageUrl.Url)
}
if err != nil {
return 0, err
diff --git a/service/worker.go b/service/worker.go
new file mode 100644
index 0000000..e950eaf
--- /dev/null
+++ b/service/worker.go
@@ -0,0 +1,26 @@
+package service
+
+import (
+ "bytes"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "strings"
+)
+
+func DoImageRequest(originUrl string) (resp *http.Response, err error) {
+ if constant.EnableWorker() {
+ common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
+ workerUrl := constant.WorkerUrl
+ if !strings.HasSuffix(workerUrl, "/") {
+ workerUrl += "/"
+ }
+ // post request to worker
+ data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
+ return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
+ } else {
+ common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
+ return http.Get(originUrl)
+ }
+}
diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js
index 6abc6f1..44b2ce2 100644
--- a/web/src/components/SystemSetting.js
+++ b/web/src/components/SystemSetting.js
@@ -27,6 +27,8 @@ const SystemSetting = () => {
SMTPFrom: '',
SMTPToken: '',
ServerAddress: '',
+ WorkerUrl: '',
+ WorkerValidKey: '',
EpayId: '',
EpayKey: '',
Price: 7.3,
@@ -145,6 +147,8 @@ const SystemSetting = () => {
name === 'Notice' ||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
name === 'ServerAddress' ||
+ name === 'WorkerUrl' ||
+ name === 'WorkerValidKey' ||
name === 'EpayId' ||
name === 'EpayKey' ||
name === 'Price' ||
@@ -172,6 +176,14 @@ const SystemSetting = () => {
await updateOption('ServerAddress', ServerAddress);
};
+ const submitWorker = async () => {
+ let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl);
+ await updateOption('WorkerUrl', WorkerUrl);
+ if (inputs.WorkerValidKey !== '') {
+ await updateOption('WorkerValidKey', inputs.WorkerValidKey);
+ }
+ }
+
const submitPayAddress = async () => {
if (inputs.ServerAddress === '') {
showError('请先填写服务器地址');
@@ -327,6 +339,28 @@ const SystemSetting = () => {
更新服务器地址
+
+
+
+
+
+
+ 更新Worker设置
+
支付设置(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)