diff --git a/common/constants.go b/common/constants.go index 96c0b57..f5dbb3d 100644 --- a/common/constants.go +++ b/common/constants.go @@ -12,7 +12,6 @@ import ( var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var SystemName = "New API" -var ServerAddress = "http://localhost:3000" var Footer = "" var Logo = "" var TopUpLink = "" diff --git a/constant/system.go b/constant/system.go new file mode 100644 index 0000000..b2976e4 --- /dev/null +++ b/constant/system.go @@ -0,0 +1,9 @@ +package constant + +var ServerAddress = "http://localhost:3000" +var WorkerUrl = "" +var WorkerValidKey = "" + +func EnableWorker() bool { + return WorkerUrl != "" +} diff --git a/controller/midjourney.go b/controller/midjourney.go index 2d538f1..508c5dd 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -235,7 +235,7 @@ func GetAllMidjourney(c *gin.Context) { } if constant.MjForwardUrlEnabled { for i, midjourney := range logs { - midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId + midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId logs[i] = midjourney } } @@ -267,7 +267,7 @@ func GetUserMidjourney(c *gin.Context) { } if constant.MjForwardUrlEnabled { for i, midjourney := range logs { - midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId + midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId logs[i] = midjourney } } diff --git a/controller/misc.go b/controller/misc.go index 8c59952..b8203f3 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -45,7 +45,7 @@ func GetStatus(c *gin.Context) { "footer_html": common.Footer, "wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_login": common.WeChatAuthEnabled, - "server_address": common.ServerAddress, + "server_address": constant.ServerAddress, "price": constant.Price, "min_topup": constant.MinTopUp, "turnstile_check": common.TurnstileCheckEnabled, @@ -203,7 +203,7 @@ func SendPasswordResetEmail(c *gin.Context) { } code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) - link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code) subject := fmt.Sprintf("%s密码重置", common.SystemName) content := fmt.Sprintf("

您好,你正在进行%s密码重置。

"+ "

点击 此处 进行密码重置。

"+ diff --git a/controller/topup.go b/controller/topup.go index ebb24a9..90c3f77 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -92,7 +92,7 @@ func RequestEpay(c *gin.Context) { payType = epay.WechatPay } callBackAddress := service.GetCallbackAddress() - returnUrl, _ := url.Parse(common.ServerAddress + "/log") + returnUrl, _ := url.Parse(constant.ServerAddress + "/log") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) client := GetEpayClient() diff --git a/model/option.go b/model/option.go index bfa7ddc..6aa59cb 100644 --- a/model/option.go +++ b/model/option.go @@ -59,6 +59,8 @@ func InitOptionMap() { common.OptionMap["SystemName"] = common.SystemName common.OptionMap["Logo"] = common.Logo common.OptionMap["ServerAddress"] = "" + common.OptionMap["WorkerUrl"] = constant.WorkerUrl + common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey common.OptionMap["PayAddress"] = "" common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["EpayId"] = "" @@ -232,7 +234,11 @@ func updateOptionMap(key string, value string) (err error) { case "SMTPToken": common.SMTPToken = value case "ServerAddress": - common.ServerAddress = value + constant.ServerAddress = value + case "WorkerUrl": + constant.WorkerUrl = value + case "WorkerValidKey": + constant.WorkerValidKey = value case "PayAddress": constant.PayAddress = value case "CustomCallbackAddress": diff --git a/model/token.go b/model/token.go index 056156e..d921b1d 100644 --- a/model/token.go +++ b/model/token.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/constant" "strconv" "strings" ) @@ -297,7 +298,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo prompt = "您的额度已用尽" } if email != "" { - topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) + topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress) err = common.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 7195bcf..f9c6ab3 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -138,11 +138,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR // 判断是否是url if strings.HasPrefix(imageUrl.Url, "http") { // 是url,获取图片的类型和base64编码的数据 - mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url) + mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url) claudeMediaMessage.Source.MediaType = mimeType claudeMediaMessage.Source.Data = data } else { - _, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url) + _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url) if err != nil { return nil, err } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 79644af..6d45b57 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -74,7 +74,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques if imageNum > GeminiVisionMaxImageNum { continue } - mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) parts = append(parts, GeminiPart{ InlineData: &GeminiInlineData{ MimeType: mimeType, diff --git a/relay/relay-mj.go b/relay/relay-mj.go index c83493b..f06956a 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -111,7 +111,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.ImageUrl = "" if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled { - midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId + midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId if originTask.Status != "SUCCESS" { midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) } diff --git a/service/epay.go b/service/epay.go index 4678157..46ea401 100644 --- a/service/epay.go +++ b/service/epay.go @@ -1,13 +1,12 @@ package service import ( - "one-api/common" "one-api/constant" ) func GetCallbackAddress() string { if constant.CustomCallbackAddress == "" { - return common.ServerAddress + return constant.ServerAddress } return constant.CustomCallbackAddress } diff --git a/common/image.go b/service/image.go similarity index 84% rename from common/image.go rename to service/image.go index 41ff51f..f3eddff 100644 --- a/common/image.go +++ b/service/image.go @@ -1,4 +1,4 @@ -package common +package service import ( "bytes" @@ -8,7 +8,7 @@ import ( "golang.org/x/image/webp" "image" "io" - "net/http" + "one-api/common" "strings" ) @@ -31,25 +31,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e return config, format, base64String, err } -func IsImageUrl(url string) (bool, error) { - resp, err := http.Head(url) - if err != nil { - return false, err - } - if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { - return false, nil - } - return true, nil -} - // GetImageFromUrl 获取图片的类型和base64编码的数据 func GetImageFromUrl(url string) (mimeType string, data string, err error) { - isImage, err := IsImageUrl(url) - if !isImage { + resp, err := DoImageRequest(url) + if err != nil { return } - resp, err := http.Get(url) - if err != nil { + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { return } defer resp.Body.Close() @@ -64,16 +52,21 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { } func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { - response, err := http.Get(imageUrl) + response, err := DoImageRequest(imageUrl) if err != nil { - SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) + common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) return image.Config{}, "", err } defer response.Body.Close() + if response.StatusCode != 200 { + err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status)) + return image.Config{}, "", err + } + var readData []byte for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { - SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) + common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) // 从response.Body读取更多的数据直到达到当前的限制 additionalData := make([]byte, limit-int64(len(readData))) @@ -99,11 +92,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) { config, format, err := image.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) - SysLog(err.Error()) + common.SysLog(err.Error()) config, err = webp.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) - SysLog(err.Error()) + common.SysLog(err.Error()) } format = "webp" } diff --git a/service/token_counter.go b/service/token_counter.go index c540ac5..cdca1fd 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -86,11 +86,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in var err error var format string if strings.HasPrefix(imageUrl.Url, "http") { - common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url)) - config, format, err = common.DecodeUrlImageData(imageUrl.Url) + config, format, err = DecodeUrlImageData(imageUrl.Url) } else { common.SysLog(fmt.Sprintf("decoding image")) - config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url) + config, format, _, err = DecodeBase64ImageData(imageUrl.Url) } if err != nil { return 0, err diff --git a/service/worker.go b/service/worker.go new file mode 100644 index 0000000..e950eaf --- /dev/null +++ b/service/worker.go @@ -0,0 +1,26 @@ +package service + +import ( + "bytes" + "fmt" + "net/http" + "one-api/common" + "one-api/constant" + "strings" +) + +func DoImageRequest(originUrl string) (resp *http.Response, err error) { + if constant.EnableWorker() { + common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl)) + workerUrl := constant.WorkerUrl + if !strings.HasSuffix(workerUrl, "/") { + workerUrl += "/" + } + // post request to worker + data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`) + return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data)) + } else { + common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl)) + return http.Get(originUrl) + } +} diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index 6abc6f1..44b2ce2 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -27,6 +27,8 @@ const SystemSetting = () => { SMTPFrom: '', SMTPToken: '', ServerAddress: '', + WorkerUrl: '', + WorkerValidKey: '', EpayId: '', EpayKey: '', Price: 7.3, @@ -145,6 +147,8 @@ const SystemSetting = () => { name === 'Notice' || (name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') || name === 'ServerAddress' || + name === 'WorkerUrl' || + name === 'WorkerValidKey' || name === 'EpayId' || name === 'EpayKey' || name === 'Price' || @@ -172,6 +176,14 @@ const SystemSetting = () => { await updateOption('ServerAddress', ServerAddress); }; + const submitWorker = async () => { + let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl); + await updateOption('WorkerUrl', WorkerUrl); + if (inputs.WorkerValidKey !== '') { + await updateOption('WorkerValidKey', inputs.WorkerValidKey); + } + } + const submitPayAddress = async () => { if (inputs.ServerAddress === '') { showError('请先填写服务器地址'); @@ -327,6 +339,28 @@ const SystemSetting = () => { 更新服务器地址 +
+ 代理设置(支持 new-api-worker) +
+ + + + + + 更新Worker设置 +
支付设置(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)