Merge remote-tracking branch 'origin/main'

This commit is contained in:
1808837298@qq.com 2024-05-30 21:46:22 +08:00
commit 4dd5233f49
15 changed files with 105 additions and 39 deletions

View File

@ -12,7 +12,6 @@ import (
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API" var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var Footer = "" var Footer = ""
var Logo = "" var Logo = ""
var TopUpLink = "" var TopUpLink = ""

9
constant/system.go Normal file
View File

@ -0,0 +1,9 @@
package constant
var ServerAddress = "http://localhost:3000"
var WorkerUrl = ""
var WorkerValidKey = ""
func EnableWorker() bool {
return WorkerUrl != ""
}

View File

@ -235,7 +235,7 @@ func GetAllMidjourney(c *gin.Context) {
} }
if constant.MjForwardUrlEnabled { if constant.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }
@ -267,7 +267,7 @@ func GetUserMidjourney(c *gin.Context) {
} }
if constant.MjForwardUrlEnabled { if constant.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }

View File

@ -45,7 +45,7 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer, "footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress, "server_address": constant.ServerAddress,
"price": constant.Price, "price": constant.Price,
"min_topup": constant.MinTopUp, "min_topup": constant.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
@ -203,7 +203,7 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+

View File

@ -92,7 +92,7 @@ func RequestEpay(c *gin.Context) {
payType = epay.WechatPay payType = epay.WechatPay
} }
callBackAddress := service.GetCallbackAddress() callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(common.ServerAddress + "/log") returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
client := GetEpayClient() client := GetEpayClient()

View File

@ -59,6 +59,8 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = constant.WorkerUrl
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
common.OptionMap["PayAddress"] = "" common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = "" common.OptionMap["EpayId"] = ""
@ -232,7 +234,11 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
common.ServerAddress = value constant.ServerAddress = value
case "WorkerUrl":
constant.WorkerUrl = value
case "WorkerValidKey":
constant.WorkerValidKey = value
case "PayAddress": case "PayAddress":
constant.PayAddress = value constant.PayAddress = value
case "CustomCallbackAddress": case "CustomCallbackAddress":

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/constant"
"strconv" "strconv"
"strings" "strings"
) )
@ -297,7 +298,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {

View File

@ -138,11 +138,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
// 判断是否是url // 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url) mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data claudeMediaMessage.Source.Data = data
} else { } else {
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url) _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -74,7 +74,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
if imageNum > GeminiVisionMaxImageNum { if imageNum > GeminiVisionMaxImageNum {
continue continue
} }
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &GeminiInlineData{
MimeType: mimeType, MimeType: mimeType,

View File

@ -111,7 +111,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = "" midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled { if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" { if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
} }

View File

@ -1,13 +1,12 @@
package service package service
import ( import (
"one-api/common"
"one-api/constant" "one-api/constant"
) )
func GetCallbackAddress() string { func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" { if constant.CustomCallbackAddress == "" {
return common.ServerAddress return constant.ServerAddress
} }
return constant.CustomCallbackAddress return constant.CustomCallbackAddress
} }

View File

@ -1,4 +1,4 @@
package common package service
import ( import (
"bytes" "bytes"
@ -8,7 +8,7 @@ import (
"golang.org/x/image/webp" "golang.org/x/image/webp"
"image" "image"
"io" "io"
"net/http" "one-api/common"
"strings" "strings"
) )
@ -31,25 +31,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
return config, format, base64String, err return config, format, base64String, err
} }
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据 // GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url) resp, err := DoImageRequest(url)
if !isImage { if err != nil {
return return
} }
resp, err := http.Get(url) if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
if err != nil {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -64,16 +52,21 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
} }
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := http.Get(imageUrl) response, err := DoImageRequest(imageUrl)
if err != nil { if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err return image.Config{}, "", err
} }
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != 200 {
err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
return image.Config{}, "", err
}
var readData []byte var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
// 从response.Body读取更多的数据直到达到当前的限制 // 从response.Body读取更多的数据直到达到当前的限制
additionalData := make([]byte, limit-int64(len(readData))) additionalData := make([]byte, limit-int64(len(readData)))
@ -99,11 +92,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
config, format, err := image.DecodeConfig(reader) config, format, err := image.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
SysLog(err.Error()) common.SysLog(err.Error())
config, err = webp.DecodeConfig(reader) config, err = webp.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
SysLog(err.Error()) common.SysLog(err.Error())
} }
format = "webp" format = "webp"
} }

View File

@ -86,11 +86,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
var err error var err error
var format string var format string
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url)) config, format, err = DecodeUrlImageData(imageUrl.Url)
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else { } else {
common.SysLog(fmt.Sprintf("decoding image")) common.SysLog(fmt.Sprintf("decoding image"))
config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url) config, format, _, err = DecodeBase64ImageData(imageUrl.Url)
} }
if err != nil { if err != nil {
return 0, err return 0, err

26
service/worker.go Normal file
View File

@ -0,0 +1,26 @@
package service
import (
"bytes"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"strings"
)
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
if constant.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
workerUrl := constant.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
return http.Get(originUrl)
}
}

View File

@ -27,6 +27,8 @@ const SystemSetting = () => {
SMTPFrom: '', SMTPFrom: '',
SMTPToken: '', SMTPToken: '',
ServerAddress: '', ServerAddress: '',
WorkerUrl: '',
WorkerValidKey: '',
EpayId: '', EpayId: '',
EpayKey: '', EpayKey: '',
Price: 7.3, Price: 7.3,
@ -145,6 +147,8 @@ const SystemSetting = () => {
name === 'Notice' || name === 'Notice' ||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') || (name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
name === 'ServerAddress' || name === 'ServerAddress' ||
name === 'WorkerUrl' ||
name === 'WorkerValidKey' ||
name === 'EpayId' || name === 'EpayId' ||
name === 'EpayKey' || name === 'EpayKey' ||
name === 'Price' || name === 'Price' ||
@ -172,6 +176,14 @@ const SystemSetting = () => {
await updateOption('ServerAddress', ServerAddress); await updateOption('ServerAddress', ServerAddress);
}; };
const submitWorker = async () => {
let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl);
await updateOption('WorkerUrl', WorkerUrl);
if (inputs.WorkerValidKey !== '') {
await updateOption('WorkerValidKey', inputs.WorkerValidKey);
}
}
const submitPayAddress = async () => { const submitPayAddress = async () => {
if (inputs.ServerAddress === '') { if (inputs.ServerAddress === '') {
showError('请先填写服务器地址'); showError('请先填写服务器地址');
@ -327,6 +339,28 @@ const SystemSetting = () => {
<Form.Button onClick={submitServerAddress}> <Form.Button onClick={submitServerAddress}>
更新服务器地址 更新服务器地址
</Form.Button> </Form.Button>
<Header as='h3' inverted={isDark}>
代理设置支持 <a href='https://github.com/Calcium-Ion/new-api-worker' target='_blank' rel='noreferrer'>new-api-worker</a>
</Header>
<Form.Group widths='equal'>
<Form.Input
label='Worker地址不填写则不启用代理'
placeholder='例如https://workername.yourdomain.workers.dev'
value={inputs.WorkerUrl}
name='WorkerUrl'
onChange={handleInputChange}
/>
<Form.Input
label='Worker密钥根据你部署的 Worker 填写'
placeholder='例如your_secret_key'
value={inputs.WorkerValidKey}
name='WorkerValidKey'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={submitWorker}>
更新Worker设置
</Form.Button>
<Divider /> <Divider />
<Header as='h3' inverted={isDark}> <Header as='h3' inverted={isDark}>
支付设置当前仅支持易支付接口默认使用上方服务器地址作为回调地址 支付设置当前仅支持易支付接口默认使用上方服务器地址作为回调地址