merged v4.0.9 and fixed conflicts

This commit is contained in:
RockYang
2024-09-05 11:02:32 +08:00
47 changed files with 1927 additions and 427 deletions

View File

@@ -109,13 +109,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt))
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
var user model.User
@@ -124,10 +124,28 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", errors.New("insufficient of power")
}
// 更新用户算力
tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
// get image generation API KEY
var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI.Value).
Where("type", "img").
tx = s.db.Where("type", "img").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
@@ -139,25 +157,28 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
reqBody := imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}
logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody)
request := s.httpClient.R().SetHeader("Content-Type", "application/json")
if apiKey.Platform == types.Azure.Value {
request = request.SetHeader("api-key", apiKey.Value)
} else {
request = request.SetHeader("Authorization", "Bearer "+apiKey.Value)
}
r, err := request.SetBody(reqBody).SetErrorResult(&errRes).SetSuccessResult(&res).Post(apiKey.ApiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}
if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error)
return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
}
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
@@ -181,25 +202,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
}
// 更新用户算力
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
return content, nil
}

View File

@@ -91,7 +91,7 @@ func (s *LicenseService) SyncLicense() {
if err != nil {
retryCounter++
if retryCounter < 5 {
logger.Error(err)
logger.Warn(err)
}
s.license.IsActive = false
} else {

View File

@@ -179,14 +179,14 @@ func (p *ServicePool) HasAvailableService() bool {
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var items []model.MidJourneyJob
var jobs []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&items)
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range items {
for _, job := range jobs {
// 失败或者 30 分钟还没完成的任务删除并退回算力
if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
p.db.Delete(&job)

View File

@@ -8,12 +8,13 @@ package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"github.com/smartwalle/alipay/v3"
"log"
"net/url"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
"net/http"
"os"
)
@@ -35,93 +36,90 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
xClient, err := alipay.New(config.AppId, priKey, !config.SandBox)
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
if err != nil {
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
}
if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil {
return nil, fmt.Errorf("error with loading App PublicKey: %v", err)
}
if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil {
return nil, fmt.Errorf("error with loading alipay RootCert: %v", err)
}
if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err)
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
SetSignType(alipay.RSA2). // 设置签名类型,不设置默认 RSA2
SetReturnUrl(config.ReturnURL). // 设置返回URL
SetNotifyUrl(config.NotifyURL)
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with load payment public key: %v", err)
}
return &AlipayService{config: &config, client: xClient}, nil
return &AlipayService{config: &config, client: client}, nil
}
func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) {
var p = alipay.TradeWapPay{}
p.NotifyURL = notifyURL
p.ReturnURL = returnURL
p.Subject = subject
p.OutTradeNo = outTradeNo
p.TotalAmount = Amount
p.ProductCode = "QUICK_WAP_WAY"
res, err := s.client.TradeWapPay(p)
if err != nil {
return "", err
}
return res.String(), err
func (s *AlipayService) PayUrlMobile(outTradeNo string, amount string, subject string) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", subject)
bm.Set("out_trade_no", outTradeNo)
bm.Set("quit_url", s.config.ReturnURL)
bm.Set("total_amount", amount)
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) {
var p = alipay.TradePagePay{}
p.NotifyURL = notifyURL
p.ReturnURL = returnURL
p.Subject = subject
p.OutTradeNo = outTradeNo
p.TotalAmount = amount
p.ProductCode = "FAST_INSTANT_TRADE_PAY"
res, err := s.client.TradePagePay(p)
if err != nil {
return "", nil
}
return res.String(), err
func (s *AlipayService) PayUrlPc(outTradeNo string, amount string, subject string) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", subject)
bm.Set("out_trade_no", outTradeNo)
bm.Set("total_amount", amount)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
return s.client.TradePagePay(context.Background(), bm)
}
// TradeVerify 交易验证
func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo {
err := s.client.VerifySign(reqForm)
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
if err != nil {
log.Println("异步通知验证签名发生错误", err)
return NotifyVo{
Status: 0,
Message: "异步通知验证签名发生错误",
Status: Failure,
Message: "error with parse notify request: " + err.Error(),
}
}
return s.TradeQuery(reqForm.Get("out_trade_no"))
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with verify sign: " + err.Error(),
}
}
return s.TradeQuery(request.Form.Get("out_trade_no"))
}
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
var p = alipay.TradeQuery{}
p.OutTradeNo = outTradeNo
rsp, err := s.client.TradeQuery(p)
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
//查询订单
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return NotifyVo{
Status: 0,
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
}
}
if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" {
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
return NotifyVo{
Status: 1,
OutTradeNo: rsp.OutTradeNo,
TradeNo: rsp.TradeNo,
Amount: rsp.TotalAmount,
Subject: rsp.Subject,
Status: Success,
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Subject: rsp.Response.Subject,
Message: "OK",
}
} else {
return NotifyVo{
Status: 0,
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo,
}
}
@@ -134,16 +132,3 @@ func readKey(filename string) (string, error) {
}
return string(data), nil
}
type NotifyVo struct {
Status int
OutTradeNo string
TradeNo string
Amount string
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == 1
}

View File

@@ -21,12 +21,12 @@ import (
"strings"
)
type PayJS struct {
type JPayService struct {
config *types.JPayConfig
}
func NewPayJS(appConfig *types.AppConfig) *PayJS {
return &PayJS{
func NewJPayService(appConfig *types.AppConfig) *JPayService {
return &JPayService{
config: &appConfig.JPayConfig,
}
}
@@ -53,7 +53,7 @@ func (r JPayReps) IsOK() bool {
return r.ReturnMsg == "SUCCESS"
}
func (js *PayJS) Pay(param JPayReq) JPayReps {
func (js *JPayService) Pay(param JPayReq) JPayReps {
param.NotifyURL = js.config.NotifyURL
var p = url.Values{}
encode := utils.JsonEncode(param)
@@ -86,13 +86,13 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
return data
}
func (js *PayJS) PayH5(p url.Values) string {
func (js *JPayService) PayH5(p url.Values) string {
p.Add("mchid", js.config.AppId)
p.Add("sign", js.sign(p))
return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
}
func (js *PayJS) sign(params url.Values) string {
func (js *JPayService) sign(params url.Values) string {
params.Del(`sign`)
var keys = make([]string, 0, 0)
for key := range params {
@@ -117,20 +117,18 @@ func (js *PayJS) sign(params url.Values) string {
return strings.ToUpper(md5res)
}
// Check 查询订单支付状态
// TradeVerify 查询订单支付状态
// @param tradeNo 支付平台交易 ID
func (js *PayJS) Check(tradeNo string) error {
func (js *JPayService) TradeVerify(tradeNo string) error {
apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
params := url.Values{}
params.Add("payjs_order_id", tradeNo)
params.Add("sign", js.sign(params))
data := strings.NewReader(params.Encode())
resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
defer resp.Body.Close()
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {

View File

@@ -0,0 +1,19 @@
package payment
type NotifyVo struct {
Status int
OutTradeNo string // 商户订单号
TradeId string // 交易ID
Amount string // 交易金额
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == Success
}
const (
Success = 0
Failure = 1
)

View File

@@ -0,0 +1,135 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
"net/http"
"time"
)
type WechatPayService struct {
config *types.WechatPayConfig
client *wechat.ClientV3
}
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
config := appConfig.WechatPayConfig
if !config.Enabled {
logger.Info("Disabled WechatPay service")
return nil, nil
}
priKey, err := readKey(config.PrivateKey)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
if err != nil {
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
}
//client.DebugSwitch = gopay.DebugOn
return &WechatPayService{config: &config, client: client}, nil
}
func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject string) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", subject).
Set("out_trade_no", outTradeNo).
Set("time_expire", expire).
Set("notify_url", s.config.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", amount).
Set("currency", "CNY")
})
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
func (s *WechatPayService) PayUrlH5(outTradeNo string, amount int, subject string, ip string) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", subject).
Set("out_trade_no", outTradeNo).
Set("time_expire", expire).
Set("notify_url", s.config.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", amount).
Set("currency", "CNY")
}).
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", ip).
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap")
})
})
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.H5Url, nil
}
type NotifyResponse struct {
Code string `json:"code"`
Message string `xml:"message"`
}
// TradeVerify 交易验证
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
}
// TODO: 这里验签程序有 Bug一直报错crypto/rsa: verification error先暂时取消验签
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
//if err != nil {
// return fmt.Errorf("error with client v3 verify sign: %v", err)
//}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
}
return NotifyVo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
}
}

View File

@@ -132,7 +132,7 @@ func (p *ServicePool) CheckTaskStatus() {
continue
}
}
time.Sleep(time.Second * 10)
time.Sleep(time.Second * 5)
}
}()
}

View File

@@ -192,7 +192,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
return
}
task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt})
errChan <- nil
}()

View File

@@ -28,8 +28,8 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
}
func (s *SmtpService) SendVerifyCode(to string, code int) error {
subject := "Geek-AI 注册验证码"
body := fmt.Sprintf("您正在注册 Geek-AI 助手账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", code)
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
body := fmt.Sprintf("您正在注册 %s 账户,注册验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", s.config.AppName, code)
auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
if s.config.UseTls {

View File

@@ -106,23 +106,26 @@ func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
continue
}
// update user
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", config.VipMonthPower))
// 记录算力变动日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerRecharge,
Amount: config.VipMonthPower,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统盘点",
Remark: fmt.Sprintf("VIP会员每月算力派发%d", config.VipMonthPower),
CreatedAt: time.Now(),
})
if u.Power < config.VipMonthPower {
power := config.VipMonthPower - u.Power
// update user
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
// 记录算力变动日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerRecharge,
Amount: power,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统盘点",
Remark: fmt.Sprintf("VIP会员每月算力派发%d", config.VipMonthPower),
CreatedAt: time.Now(),
})
}
}
}
logger.Info("月底盘点完成!")