支付,OSS 服务重构完成

This commit is contained in:
GeekMaster
2025-08-24 19:32:45 +08:00
parent 7fb0aad3c7
commit 536b4b8056
57 changed files with 1663 additions and 1358 deletions

View File

@@ -82,18 +82,21 @@ type AppServer struct {
Config *types.AppConfig
Engine *gin.Engine
SysConfig *types.SystemConfig // system config cache
Redis *redis.Client
}
func NewServer(appConfig *types.AppConfig) *AppServer {
func NewServer(appConfig *types.AppConfig, redis *redis.Client, sysConfig *types.SystemConfig) *AppServer {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
return &AppServer{
Config: appConfig,
Engine: gin.Default(),
Config: appConfig,
Redis: redis,
Engine: gin.Default(),
SysConfig: sysConfig,
}
}
func (s *AppServer) Init(debug bool, client *redis.Client) {
func (s *AppServer) Init(client *redis.Client) {
s.Engine.Use(corsMiddleware())
s.Engine.Use(staticResourceMiddleware())
s.Engine.Use(authorizeMiddleware(s, client))
@@ -104,21 +107,6 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
}
func (s *AppServer) Run(db *gorm.DB) error {
// 重命名 config 表字段
if db.Migrator().HasColumn(&model.Config{}, "config_json") {
db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
}
if db.Migrator().HasColumn(&model.Config{}, "marker") {
db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
}
if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
}
if db.Migrator().HasIndex(&model.Config{}, "marker") {
db.Migrator().DropIndex(&model.Config{}, "marker")
}
// load system configs
var sysConfig model.Config
err := db.Where("name", "system").First(&sysConfig).Error
@@ -130,57 +118,6 @@ func (s *AppServer) Run(db *gorm.DB) error {
return fmt.Errorf("failed to decode system config: %v", err)
}
// 迁移数据表
logger.Info("Migrating database tables...")
db.AutoMigrate(
&model.ChatItem{},
&model.ChatMessage{},
&model.ChatRole{},
&model.ChatModel{},
&model.InviteCode{},
&model.InviteLog{},
&model.Menu{},
&model.Order{},
&model.Product{},
&model.User{},
&model.Function{},
&model.File{},
&model.Redeem{},
&model.Config{},
&model.ApiKey{},
&model.AdminUser{},
&model.AppType{},
&model.SdJob{},
&model.SunoJob{},
&model.PowerLog{},
&model.VideoJob{},
&model.MidJourneyJob{},
&model.UserLoginLog{},
&model.DallJob{},
&model.JimengJob{},
)
// 手动删除字段
if db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
db.Migrator().DropColumn(&model.Order{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.User{}, "chat_config") {
db.Migrator().DropColumn(&model.User{}, "chat_config")
}
if db.Migrator().HasColumn(&model.ChatModel{}, "category") {
db.Migrator().DropColumn(&model.ChatModel{}, "category")
}
if db.Migrator().HasColumn(&model.ChatModel{}, "description") {
db.Migrator().DropColumn(&model.ChatModel{}, "description")
}
logger.Info("Database tables migrated successfully")
// 统计安装信息
go func() {
info, err := host.Info()

View File

@@ -11,10 +11,12 @@ import (
"bytes"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/store/model"
"geekai/utils"
"os"
"github.com/BurntSushi/toml"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
@@ -72,3 +74,78 @@ func SaveConfig(config *types.AppConfig) error {
return os.WriteFile(config.Path, buf.Bytes(), 0644)
}
func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
// 加载系统配置
var sysConfig model.Config
var baseConfig types.BaseConfig
db.Where("name", "system").First(&sysConfig)
err := utils.JsonDecode(sysConfig.Value, &baseConfig)
if err != nil {
logger.Error("load system config error: ", err)
}
// 加载许可证配置
var license types.License
sysConfig.Id = 0
db.Where("name", types.ConfigKeyLicense).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &license)
if err != nil {
logger.Error("load license config error: ", err)
}
// 加载 GeekAPI 配置
var geekAPIConfig types.GeekAPIConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyGeekAPI).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &geekAPIConfig)
if err != nil {
logger.Error("load geek service config error: ", err)
}
// 加载短信配置
var smsConfig types.SMSConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeySms).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &smsConfig)
if err != nil {
logger.Error("load sms config error: ", err)
}
// 加载 OSS 配置
var ossConfig types.OSSConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyOss).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &ossConfig)
if err != nil {
logger.Error("load oss config error: ", err)
}
// 加载 SMTP 配置
var smtpConfig types.SmtpConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeySmtp).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &smtpConfig)
if err != nil {
logger.Error("load smtp config error: ", err)
}
// 加载支付配置
var paymentConfig types.PaymentConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyPayment).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &paymentConfig)
if err != nil {
logger.Error("load payment config error: ", err)
}
return &types.SystemConfig{
Base: baseConfig,
License: license,
SMS: smsConfig,
OSS: ossConfig,
SMTP: smtpConfig,
Payment: paymentConfig,
GeekAPI: geekAPIConfig,
}
}

View File

@@ -0,0 +1,107 @@
package midware
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt"
)
// 用户授权验证
func UserAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader(types.UserAuthHeader)
if tokenString == "" {
resp.NotAuth(c, "无效的授权令牌")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "令牌无效")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "令牌过期")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "当前用户已退出登录")
c.Abort()
return
}
c.Set(types.LoginUserID, claims["user_id"])
}
}
func AdminAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader(types.AdminAuthHeader)
if tokenString == "" {
resp.NotAuth(c, "无效的授权令牌")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "令牌无效")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "令牌过期")
c.Abort()
return
}
key := fmt.Sprintf("admin/%v", claims["user_id"])
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "当前用户已退出登录")
c.Abort()
return
}
c.Set(types.AdminUserID, claims["user_id"])
}
}

View File

@@ -0,0 +1,80 @@
package midware
import (
"bytes"
"geekai/utils"
"io"
"strings"
"github.com/gin-gonic/gin"
)
// 统一参数处理
func ParameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// GET 参数处理
params := c.Request.URL.Query()
for key, values := range params {
for i, value := range values {
params[key][i] = strings.TrimSpace(value)
}
}
// update get parameters
c.Request.URL.RawQuery = params.Encode()
// skip file upload requests
contentType := c.Request.Header.Get("Content-Type")
if strings.Contains(contentType, "multipart/form-data") {
c.Next()
return
}
if strings.Contains(contentType, "application/json") {
// process POST JSON request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.Next()
return
}
// 还原请求体
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 将请求体解析为 JSON
var jsonData map[string]any
if err := c.ShouldBindJSON(&jsonData); err != nil {
c.Next()
return
}
// 对 JSON 数据中的字符串值去除两端空格
trimJSONStrings(jsonData)
// 更新请求体
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
}
c.Next()
}
}
// 递归对 JSON 数据中的字符串值去除两端空格
func trimJSONStrings(data any) {
switch v := data.(type) {
case map[string]any:
for key, value := range v {
switch valueType := value.(type) {
case string:
v[key] = strings.TrimSpace(valueType)
case map[string]any, []any:
trimJSONStrings(value)
}
}
case []any:
for i, value := range v {
switch valueType := value.(type) {
case string:
v[i] = strings.TrimSpace(valueType)
case map[string]any, []any:
trimJSONStrings(value)
}
}
}
}

View File

@@ -0,0 +1,43 @@
package midware
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
// RateLimitEvery 使用 Redis 做固定间隔限流:在 interval 内仅允许一次请求
// Key 优先使用登录用户ID若没有则退化为 route + IP
func RateLimitEvery(redisClient *redis.Client, interval time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
keyID := ""
if userID, ok := c.Get(types.LoginUserID); ok {
keyID = fmt.Sprintf("user:%s", utils.InterfaceToString(userID))
} else {
keyID = fmt.Sprintf("ip:%s", c.ClientIP())
}
fullPath := c.FullPath()
if fullPath == "" {
fullPath = c.Request.URL.Path
}
key := fmt.Sprintf("rl:%s:%s", fullPath, keyID)
okSet, err := redisClient.SetNX(context.Background(), key, 1, interval).Result()
if err != nil {
// Redis 异常时放行,避免误伤可用性
return
}
if !okSet {
c.JSON(http.StatusTooManyRequests, types.BizVo{Code: types.Failed, Message: "请求过于频繁,请稍后重试"})
c.Abort()
return
}
}
}

View File

@@ -17,18 +17,17 @@ type AppConfig struct {
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
SmtpConfig SmtpConfig // 邮件发送配置
AlipayConfig AlipayConfig // 支付宝支付渠道配置
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
GeekPayConfig GeekPayConfig // GEEK 支付配置
WechatPayConfig WxPayConfig // 微信支付渠道配置
TikaHost string // TiKa 服务器地址
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
SmtpConfig SmtpConfig // 邮件发送配置
AlipayConfig AlipayConfig // 支付宝支付渠道配置
GeekPayConfig EpayConfig // GEEK 支付配置
WechatPayConfig WxPayConfig // 微信支付渠道配置
TikaHost string // TiKa 服务器地址
}
type RedisConfig struct {
@@ -58,7 +57,7 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
type SystemConfig struct {
type BaseConfig struct {
Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
@@ -103,18 +102,29 @@ type SystemConfig struct {
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位MB
}
type SystemConfig struct {
Base BaseConfig
Payment PaymentConfig
OSS OSSConfig
SMS SMSConfig
SMTP SmtpConfig
GeekAPI GeekAPIConfig
Jimeng JimengConfig
License License
}
// 配置键名常量
const (
ConfigKeySystem = "system"
ConfigKeyNotice = "notice"
ConfigKeyAgreement = "agreement"
ConfigKeyPrivacy = "privacy"
ConfigKeyGeekService = "geekai"
ConfigKeySms = "sms"
ConfigKeySmtp = "smtp"
ConfigKeyOss = "oss"
ConfigKeyPayment = "payment"
ConfigKeySystem = "system"
ConfigKeyNotice = "notice"
ConfigKeyAgreement = "agreement"
ConfigKeyPrivacy = "privacy"
ConfigKeyGeekAPI = "geekapi"
ConfigKeyLicense = "license"
ConfigKeySms = "sms"
ConfigKeySmtp = "smtp"
ConfigKeyOss = "oss"
ConfigKeyPayment = "payment"
)

View File

@@ -23,3 +23,8 @@ type WxLoginConfig struct {
NotifyURL string `json:"notify_url"` // 登录成功回调 URL
Enabled bool `json:"enabled"` // 是否启用微信登录
}
type GeekAPIConfig struct {
Captcha CaptchaConfig
WxLogin WxLoginConfig
}

View File

@@ -13,27 +13,24 @@ const (
OrderNotPaid = OrderStatus(0)
OrderScanned = OrderStatus(1) // 已扫码
OrderPaidSuccess = OrderStatus(2)
OrderPaidFailed = OrderStatus(3)
)
type OrderRemark struct {
Days int `json:"days"` // 有效期
Power int `json:"power"` // 增加算力点数
Name string `json:"name"` // 产品名称
Price float64 `json:"price"`
Discount float64 `json:"discount"`
Days int `json:"days"` // 有效期
Power int `json:"power"` // 增加算力点数
Name string `json:"name"` // 产品名称
Price float64 `json:"price"`
}
var PayMethods = map[string]string{
// PayChannel 支付渠道
var PayChannel = map[string]string{
"alipay": "支付宝商号",
"wechat": "微信商号",
"hupi": "虎皮椒",
"geek": "易支付",
"wxpay": "微信商号",
"epay": "易支付",
}
var PayNames = map[string]string{
var PayWays = map[string]string{
"alipay": "支付宝",
"wxpay": "微信支付",
"qqpay": "QQ钱包",
"jdpay": "京东支付",
"douyin": "抖音支付",
"paypal": "PayPal支付",
}

View File

@@ -1,19 +1,9 @@
package types
type PaymentConfig struct {
AlipayConfig AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
GeekPayConfig GeekPayConfig `json:"geekpay"` // GEEK 支付配置
WxPayConfig WxPayConfig `json:"wxpay"` // 微信支付渠道配置
HuPiPayConfig HuPiPayConfig `json:"hupi"` // 虎皮椒支付渠道配置
}
type HuPiPayConfig struct { //虎皮椒第四方支付配置
Enabled bool // 是否启用该支付通道
AppId string // App ID
AppSecret string // app 密钥
ApiURL string // 支付网关
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
Epay EpayConfig `json:"epay"` // GEEK 支付配置
WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置
}
// AlipayConfig 支付宝支付配置
@@ -53,8 +43,8 @@ func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
c.Domain == other.Domain
}
// GeekPayConfig 易支付配置
type GeekPayConfig struct {
// EpayConfig 易支付配置
type EpayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
AppId string `json:"app_id"` // 商户 ID
PrivateKey string `json:"private_key"` // 私钥
@@ -62,7 +52,7 @@ type GeekPayConfig struct {
Domain string `json:"domain"` // 支付回调域名
}
func (c *GeekPayConfig) Equal(other *GeekPayConfig) bool {
func (c *EpayConfig) Equal(other *EpayConfig) bool {
return c.AppId == other.AppId &&
c.PrivateKey == other.PrivateKey &&
c.ApiURL == other.ApiURL &&

View File

@@ -8,6 +8,7 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const LoginUserID = "LOGIN_USER_ID"
const AdminUserID = "ADMIN_USER_ID"
const LoginUserCache = "LOGIN_USER_CACHE"
const UserAuthHeader = "Authorization"

View File

@@ -17,8 +17,6 @@ type SMSConfig struct {
type SmsConfigAli struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
}
@@ -27,7 +25,6 @@ type SmsConfigAli struct {
type SmsConfigBao struct {
Username string //短信宝平台注册的用户名
Password string //短信宝平台注册的密码
Domain string //域名
Sign string // 短信签名
CodeTemplate string // 验证码短信模板 匹配
}

View File

@@ -27,6 +27,7 @@ require (
require (
github.com/go-pay/gopay v1.5.101
github.com/go-rod/rod v0.116.2
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/go-tika v0.3.1
github.com/microcosm-cc/bluemonday v1.0.26
github.com/sashabaranov/go-openai v1.38.1

View File

@@ -87,6 +87,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=

View File

@@ -19,9 +19,10 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -72,7 +73,7 @@ func (h *ManagerHandler) Login(c *gin.Context) {
return
}
if h.App.SysConfig.EnabledVerify {
if h.App.SysConfig.Base.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)

View File

@@ -47,7 +47,6 @@ func (h *ConfigHandler) RegisterRoutes() {
group.GET("get", h.Get)
group.POST("active", h.Active)
group.POST("test", h.Test)
group.GET("fixData", h.FixData)
group.GET("license", h.GetLicense)
}
@@ -70,7 +69,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
resp.ERROR(c, "系统配置解析失败: "+err.Error())
return
}
if (sys.Copyright != payload.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
if (sys.Base.Copyright != payload.ConfigBak.Base.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
return
}
@@ -162,69 +161,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license)
}
// FixData 修复数据
func (h *ConfigHandler) FixData(c *gin.Context) {
resp.ERROR(c, "当前升级版本没有数据需要修正!")
//var fixed bool
//version := "data_fix_4.1.4"
//err := h.levelDB.Get(version, &fixed)
//if err == nil || fixed {
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
// return
//}
//tx := h.DB.Begin()
//var users []model.User
//err = tx.Find(&users).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, user := range users {
// if user.Email != "" || user.Mobile != "" {
// continue
// }
// if utils.IsValidEmail(user.Username) {
// user.Email = user.Username
// } else if utils.IsValidMobile(user.Username) {
// user.Mobile = user.Username
// }
// err = tx.Save(&user).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//
//var orders []model.Order
//err = h.DB.Find(&orders).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, order := range orders {
// if order.PayWay == "支付宝" {
// order.PayWay = "alipay"
// order.PayType = "alipay"
// } else if order.PayWay == "微信支付" {
// order.PayWay = "wechat"
// order.PayType = "wxpay"
// } else if order.PayWay == "hupi" {
// order.PayType = "wxpay"
// }
// err = tx.Save(&order).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//tx.Commit()
//err = h.levelDB.Put(version, true)
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//resp.SUCCESS(c)
}

View File

@@ -76,16 +76,16 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
payChannel, ok := types.PayChannel[item.Channel]
if !ok {
payMethod = item.PayWay
payChannel = item.Channel
}
payName, ok := types.PayNames[item.PayType]
payWays, ok := types.PayWays[item.PayWay]
if !ok {
payName = item.PayWay
payWays = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
order.ChannelName = payChannel
order.PayName = payWays
list = append(list, order)
} else {
logger.Error(err)

View File

@@ -41,7 +41,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
return
}
if h.App.SysConfig.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.MaxFileSize)*1024*1024 {
if h.App.SysConfig.Base.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.Base.MaxFileSize)*1024*1024 {
resp.ERROR(c, "文件大小超过限制")
return
}

View File

@@ -21,11 +21,11 @@ import (
type CaptchaHandler struct {
App *core.AppServer
service *service.CaptchaService
config *types.CaptchaConfig
config types.CaptchaConfig
}
func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, config *types.CaptchaConfig) *CaptchaHandler {
return &CaptchaHandler{App: app, service: s, config: config}
func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, sysConfig *types.SystemConfig) *CaptchaHandler {
return &CaptchaHandler{App: app, service: s, config: sysConfig.GeekAPI.Captcha}
}
// RegisterRoutes 注册路由

View File

@@ -244,9 +244,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
// 加载聊天上下文
chatCtx := make([]any, 0)
messages := make([]any, 0)
if h.App.SysConfig.EnableContext {
if h.App.SysConfig.Base.EnableContext {
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
if h.App.SysConfig.Base.ContextDeep > 0 {
var historyMessages []model.ChatMessage
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
if input.LastMsgId > 0 { // 重新生成逻辑
@@ -254,7 +254,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
// 删除对应的聊天记录
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
}
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error
err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error
if err == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
@@ -282,7 +282,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
if len(chatCtx) >= h.App.SysConfig.Base.ContextDeep {
break
}

View File

@@ -88,7 +88,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
Power: chatModel.Power,
}
job := model.DallJob{

View File

@@ -197,7 +197,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return
}
if user.Power < h.App.SysConfig.DallPower {
if user.Power < h.App.SysConfig.Base.DallPower {
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
return
}
@@ -209,17 +209,17 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
Prompt: prompt,
ModelId: 0,
ModelName: "dall-e-3",
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
N: 1,
Quality: "standard",
Size: "1024x1024",
Style: "vivid",
Power: h.App.SysConfig.DallPower,
Power: h.App.SysConfig.Base.DallPower,
}
job := model.DallJob{
UserId: user.Id,
Prompt: prompt,
Power: h.App.SysConfig.DallPower,
Power: h.App.SysConfig.Base.DallPower,
TaskInfo: utils.JsonEncode(task),
}
err := h.DB.Create(&job).Error

View File

@@ -39,7 +39,7 @@ func (h *MenuHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{})
session = session.Where("enabled", true)
if index {
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
session = session.Where("id IN ?", h.App.SysConfig.Base.IndexNavs)
}
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil {

View File

@@ -65,7 +65,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false
}
if user.Power < h.App.SysConfig.MjPower {
if user.Power < h.App.SysConfig.Base.MjPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
@@ -171,8 +171,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
TranslateModelId: h.App.SysConfig.AssistantModelId,
Mode: h.App.SysConfig.Base.MjMode,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
}
job := model.MidJourneyJob{
Type: data.TaskType,
@@ -181,7 +181,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower,
Power: h.App.SysConfig.Base.MjPower,
CreatedAt: time.Now(),
}
opt := "绘图"
@@ -244,7 +244,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Index: data.Index,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
Mode: h.App.SysConfig.Base.MjMode,
}
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
@@ -252,7 +252,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
TaskId: taskId,
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Power: h.App.SysConfig.MjActionPower,
Power: h.App.SysConfig.Base.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
@@ -299,7 +299,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
ChannelId: data.ChannelId,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
Mode: h.App.SysConfig.Base.MjMode,
}
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
@@ -308,7 +308,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
TaskId: taskId,
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Power: h.App.SysConfig.MjActionPower,
Power: h.App.SysConfig.Base.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {

View File

@@ -55,20 +55,21 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
payChannel, ok := types.PayChannel[item.Channel]
if !ok {
payMethod = item.PayWay
payChannel = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
payWays, ok := types.PayWays[item.PayWay]
if !ok {
payName = item.PayWay
payWays = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
order.ChannelName = payChannel
order.PayName = payWays
list = append(list, order)
} else {
logger.Error(err)
}
}
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))

View File

@@ -9,8 +9,10 @@ package handler
import (
"embed"
"errors"
"fmt"
"geekai/core"
"geekai/core/midware"
"geekai/core/types"
"geekai/service"
"geekai/service/payment"
@@ -33,63 +35,143 @@ type PayWay struct {
// PaymentHandler 支付服务回调 handler
type PaymentHandler struct {
BaseHandler
alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService
geekPayService *payment.GeekPayService
wechatPayService *payment.WechatPayService
snowflake *service.Snowflake
userService *service.UserService
fs embed.FS
lock sync.Mutex
signKey string // 用来签名的随机秘钥
alipayService *payment.AlipayService
epayService *payment.EPayService
wxpayService *payment.WxPayService
snowflake *service.Snowflake
userService *service.UserService
fs embed.FS
lock sync.Mutex
config *types.PaymentConfig
}
func NewPaymentHandler(
server *core.AppServer,
alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService,
geekPayService *payment.GeekPayService,
wechatPayService *payment.WechatPayService,
geekPayService *payment.EPayService,
wxpayService *payment.WxPayService,
db *gorm.DB,
userService *service.UserService,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler {
fs embed.FS,
sysConfig *types.SystemConfig) *PaymentHandler {
return &PaymentHandler{
alipayService: alipayService,
huPiPayService: huPiPayService,
geekPayService: geekPayService,
wechatPayService: wechatPayService,
snowflake: snowflake,
userService: userService,
fs: fs,
lock: sync.Mutex{},
alipayService: alipayService,
epayService: geekPayService,
wxpayService: wxpayService,
snowflake: snowflake,
userService: userService,
fs: fs,
lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
signKey: utils.RandString(32),
config: &sysConfig.Payment,
}
}
// RegisterRoutes 注册路由
func (h *PaymentHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/payment/")
group.POST("doPay", h.Pay)
group.GET("payWays", h.GetPayWays)
group.POST("notify/alipay", h.AlipayNotify)
group.GET("notify/geek", h.GeekPayNotify)
group.POST("notify/wechat", h.WechatPayNotify)
group.POST("notify/hupi", h.HuPiPayNotify)
rg := h.App.Engine.Group("/api/payment/")
// 支付回调接口(公开)
rg.POST("notify/alipay", h.AlipayNotify)
rg.GET("notify/geek", h.GeekPayNotify)
rg.POST("notify/wechat", h.WechatPayNotify)
// 需要用户登录的接口
rg.Use(midware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
rg.POST("create", h.Pay)
}
// 同步订单状态
h.StartSyncOrders()
}
func (h *PaymentHandler) StartSyncOrders() {
go func() {
for {
err := h.SyncOrders()
if err != nil {
logger.Error(err)
}
time.Sleep(time.Second * 5)
}
}()
}
// SyncOrders 同步订单状态
func (h *PaymentHandler) SyncOrders() error {
defer func() {
if err := recover(); err != nil {
logger.Errorf("同步订单状态发生异常: %v", err)
}
}()
var orders []model.Order
err := h.DB.Where("status", types.OrderNotPaid).Where("checked", false).Find(&orders).Error
if err != nil {
return err
}
for _, order := range orders {
// 超时15分钟的订单直接标记为已关闭
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
h.DB.Model(&model.Order{}).Where("id", order.Id).Update("checked", true)
return errors.New("订单超时")
}
// 查询订单状态
var res payment.OrderInfo
switch order.Channel {
case payment.PayChannelEpay:
res, err = h.epayService.Query(order.OrderNo)
if err != nil {
return fmt.Errorf("error with query order info: %v", err)
}
// 微信支付
case payment.PayChannelWX:
res, err = h.wxpayService.Query(order.OrderNo)
if err != nil {
return fmt.Errorf("error with query order info: %v", err)
}
case payment.PayChannelAL:
res, err = h.alipayService.Query(order.OrderNo)
if err != nil {
return fmt.Errorf("error with query order info: %v", err)
}
}
// 订单已关闭
if res.Closed() {
h.DB.Model(&model.Order{}).Where("id", order.Id).Updates(map[string]any{
"checked": true,
"status": types.OrderPaidFailed,
})
return errors.New("订单已关闭")
}
// 订单未支付,不处理,继续轮询
if !res.Success() {
return nil
}
// 订单支付成功
err = h.paySuccess(res)
if err != nil {
return fmt.Errorf("error with deal order: %v", err)
}
}
return nil
}
func (h *PaymentHandler) Pay(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"`
PayType string `json:"pay_type"`
ProductId int `json:"product_id"`
UserId int `json:"user_id"`
Device string `json:"device"`
Host string `json:"host"`
PayWay string `json:"pay_way,omitempty"` // 支付方式:支付宝,微信
Pid int `json:"pid,omitempty"`
Device string `json:"device,omitempty"`
Domain string `json:"domain,omitempty"` // 支付回调域名
Channel string `json:"channel,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -97,7 +179,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
}
var product model.Product
err := h.DB.Where("id", data.ProductId).First(&product).Error
err := h.DB.Where("id", data.Pid).First(&product).Error
if err != nil {
resp.ERROR(c, "Product not found")
return
@@ -108,136 +190,118 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
userId := h.GetLoginUserId(c)
var user model.User
err = h.DB.Where("id", data.UserId).First(&user).Error
err = h.DB.Where("id", userId).First(&user).Error
if err != nil {
resp.NotAuth(c)
return
}
amount := product.Discount
var payURL, returnURL, notifyURL string
amount := product.Price
var payURL, notifyURL string
switch data.PayWay {
case "alipay":
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
notifyURL = h.App.Config.AlipayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
}
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
returnURL = h.App.Config.AlipayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
money := fmt.Sprintf("%.2f", amount)
if data.Device == "wechat" {
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
case "wxpay":
logger.Debugf("微信支付,%+v", data)
data.Channel = payment.PayChannelWX
// 优先使用微信官方支付
if h.config.WxPay.Enabled {
data.Channel = "wxpay"
if h.config.WxPay.Domain != "" {
data.Domain = h.config.WxPay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Domain)
payURL, err = h.wxpayService.Pay(payment.PayRequest{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
} else {
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
}
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
return
}
break
case "wechat":
if h.App.Config.WechatPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
}
if data.Device == "wechat" {
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
TotalFee: fmt.Sprintf("%d", int(amount*100)),
Subject: product.Name,
NotifyURL: notifyURL,
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: payment.PayWayWX,
})
} else {
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
if err != nil {
resp.ERROR(c, err.Error())
return
}
} else if h.config.Epay.Enabled { // 聚合支付
logger.Debugf("聚合支付%+v", data)
data.Channel = payment.PayChannelEpay
if h.config.Epay.Domain != "" {
data.Domain = h.config.Epay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Domain)
params := payment.PayRequest{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
Subject: product.Name,
TotalFee: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: payment.PayWayWX,
NotifyURL: notifyURL,
}
r, err := h.epayService.Pay(params)
logger.Debugf("请求支付结果,%+v", r)
if err != nil {
resp.ERROR(c, err.Error())
return
} else {
payURL = r
}
} else {
resp.ERROR(c, "系统没有配置可用的支付渠道!")
return
}
case "alipay":
if h.config.Alipay.Enabled {
logger.Debugf("支付宝,%+v", data)
data.Channel = payment.PayChannelAL
if h.config.Alipay.Domain != "" { // 用于本地调试支付
data.Domain = h.config.Alipay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Domain)
money := fmt.Sprintf("%.2f", amount)
payURL, err = h.alipayService.Pay(payment.PayRequest{
Device: data.Device,
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
NotifyURL: notifyURL,
})
}
if err != nil {
resp.ERROR(c, err.Error())
return
}
break
case "hupi":
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
}
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
WapName: "GeekAI助手",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
payURL = r.URL
break
case "geek":
if h.App.Config.GeekPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
}
if h.App.Config.GeekPayConfig.ReturnURL != "" {
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
}
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
params := payment.GeekPayParams{
OutTradeNo: orderNo,
Method: "web",
Name: product.Name,
Money: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
Type: data.PayType,
ReturnURL: returnURL,
NotifyURL: notifyURL,
}
res, err := h.geekPayService.Pay(params)
if err != nil {
resp.ERROR(c, err.Error())
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
return
}
} else if h.config.Epay.Enabled { // 聚合支付
logger.Debugf("聚合支付,%+v", data)
data.Channel = payment.PayChannelEpay
if h.config.Epay.Domain != "" {
data.Domain = h.config.Epay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Domain)
params := payment.PayRequest{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: data.PayWay,
NotifyURL: notifyURL,
}
r, err := h.epayService.Pay(params)
if err != nil {
resp.ERROR(c, err.Error())
return
} else {
payURL = r
}
} else {
resp.ERROR(c, "系统没有配置可用的支付渠道!")
return
}
payURL = res.PayURL
default:
resp.ERROR(c, "不支持的支付渠道")
return
@@ -245,43 +309,41 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
}
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: data.PayWay,
PayType: data.PayType,
Remark: utils.JsonEncode(remark),
UserId: user.Id,
Username: user.Username,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: data.PayWay,
Channel: data.Channel,
Remark: utils.JsonEncode(remark),
}
err = h.DB.Create(&order).Error
if err != nil {
resp.ERROR(c, "error with create order: "+err.Error())
return
}
resp.SUCCESS(c, payURL)
resp.SUCCESS(c, gin.H{"pay_url": payURL, "order_no": orderNo})
}
// 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// 支付成功处理
func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
h.lock.Lock()
defer h.lock.Unlock()
var order model.Order
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
err := h.DB.Where("order_no", info.OutTradeNo).First(&order).Error
if err != nil {
return fmt.Errorf("error with fetch order: %v", err)
}
h.lock.Lock()
defer h.lock.Unlock()
// 已支付订单,直接返回
if order.Status == types.OrderPaidSuccess {
return nil
@@ -301,18 +363,20 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// 增加用户算力
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
Type: types.PowerRecharge,
Model: order.PayWay,
Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
Type: types.PowerRecharge,
Model: order.Subject,
Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
if err != nil {
return err
}
// 更新订单状态
order.PayTime = time.Now().Unix()
order.PayTime = utils.Str2stamp(info.PayTime)
order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo
order.TradeNo = info.TradeId
order.Checked = true
err = h.DB.Updates(&order).Error
if err != nil {
return fmt.Errorf("error with update order info: %v", err)
@@ -328,54 +392,6 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
return nil
}
// GetPayWays 获取支付方式
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
payWays := make([]gin.H, 0)
if h.App.Config.AlipayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
}
if h.App.Config.HuPiPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
}
if h.App.Config.GeekPayConfig.Enabled {
for _, v := range h.App.Config.GeekPayConfig.Methods {
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
}
}
if h.App.Config.WechatPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
}
resp.SUCCESS(c, payWays)
}
// HuPiPayNotify 虎皮椒支付异步回调
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
orderNo := c.Request.Form.Get("trade_order_id")
tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
if err = h.huPiPayService.Check(orderNo); err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// AlipayNotify 支付宝支付回调
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
err := c.Request.ParseForm()
@@ -384,16 +400,15 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return
}
result := h.alipayService.TradeVerify(c.Request)
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
if !result.Success() {
logger.Error("订单校验失败:", result.Message)
orderInfo, err := h.alipayService.Query(c.Request.Form.Get("out_trade_no"))
logger.Infof("收到支付宝商号订单支付回调:%+v", orderInfo)
if !orderInfo.Success() {
logger.Errorf("订单校验失败:%v", err)
c.String(http.StatusOK, "fail")
return
}
tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(result.OutTradeNo, tradeNo)
err = h.paySuccess(orderInfo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
@@ -411,20 +426,27 @@ func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
}
logger.Infof("收到GeekPay订单支付回调%+v", params)
// 检查支付状态
// 检查支付状态, 如果未支付,则返回成功
if params["trade_status"] != "TRADE_SUCCESS" {
c.String(http.StatusOK, "success")
return
}
sign := h.geekPayService.Sign(params)
sign := h.epayService.Sign(params)
if sign != c.Query("sign") {
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
c.String(http.StatusOK, "fail")
return
}
// 查询订单状态
order, err := h.epayService.Query(params["out_trade_no"])
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
err := h.notify(params["out_trade_no"], params["trade_no"])
err = h.paySuccess(order)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
@@ -442,18 +464,15 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
return
}
result := h.wechatPayService.TradeVerify(c.Request)
logger.Infof("收到微信商号订单支付回调:%+v", result)
if !result.Success() {
logger.Error("订单校验失败:", err)
c.JSON(http.StatusBadRequest, gin.H{
"code": "FAIL",
"message": err.Error(),
})
orderInfo, err := h.wxpayService.TradeVerify(c.Request)
logger.Infof("收到微信商号订单支付回调:%+v", orderInfo)
if err != nil {
logger.Errorf("订单校验失败:%v", err)
c.JSON(http.StatusBadRequest, gin.H{"code": "FAIL"})
return
}
err = h.notify(result.OutTradeNo, result.TradeId)
err = h.paySuccess(orderInfo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")

View File

@@ -57,15 +57,15 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
if h.App.SysConfig.Base.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成歌词",
@@ -88,14 +88,14 @@ func (h *PromptHandler) Image(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
if h.App.SysConfig.Base.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成绘画提示词",
@@ -117,15 +117,15 @@ func (h *PromptHandler) Video(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
if h.App.SysConfig.Base.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成视频脚本",
@@ -167,9 +167,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
}
func (h *PromptHandler) getPromptModel() string {
if h.App.SysConfig.AssistantModelId > 0 {
if h.App.SysConfig.Base.AssistantModelId > 0 {
var chatModel model.ChatModel
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
h.DB.Where("id", h.App.SysConfig.Base.AssistantModelId).First(&chatModel)
return chatModel.Value
}
return "gpt-4o"

View File

@@ -160,7 +160,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
return
}
if user.Power < h.App.SysConfig.AdvanceVoicePower {
if user.Power < h.App.SysConfig.Base.AdvanceVoicePower {
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
return
}
@@ -204,7 +204,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 扣减算力
err = h.userService.DecreasePower(userId, h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.AdvanceVoicePower, model.PowerLog{
Type: types.PowerConsume,
Model: "advanced-voice",
Remark: "实时语音通话",

View File

@@ -73,7 +73,7 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
return false
}
if user.Power < h.App.SysConfig.SdPower {
if user.Power < h.App.SysConfig.Base.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
@@ -141,7 +141,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
HdSteps: data.HdSteps,
},
UserId: userId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
}
job := model.SdJob{
@@ -152,7 +152,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
TaskInfo: utils.JsonEncode(task),
Prompt: data.Prompt,
Progress: 0,
Power: h.App.SysConfig.SdPower,
Power: h.App.SysConfig.Base.SdPower,
CreatedAt: time.Now(),
}
res := h.DB.Create(&job)

View File

@@ -25,7 +25,7 @@ const CodeStorePrefix = "/verify/codes/"
type SmsHandler struct {
BaseHandler
redis *redis.Client
sms *sms.ServiceManager
sms *sms.SmsManager
smtp *service.SmtpService
captcha *service.CaptchaService
}
@@ -33,7 +33,7 @@ type SmsHandler struct {
func NewSmsHandler(
app *core.AppServer,
client *redis.Client,
sms *sms.ServiceManager,
sms *sms.SmsManager,
smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler {
return &SmsHandler{
@@ -62,7 +62,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.App.SysConfig.EnabledVerify {
if h.App.SysConfig.Base.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
@@ -78,14 +78,14 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
code := utils.RandomNumber(6)
var err error
if strings.Contains(data.Receiver, "@") { // email
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "email") {
resp.ERROR(c, "系统已禁用邮箱注册!")
return
}
// 检查邮箱后缀是否在白名单
if len(h.App.SysConfig.EmailWhiteList) > 0 {
if len(h.App.SysConfig.Base.EmailWhiteList) > 0 {
inWhiteList := false
for _, suffix := range h.App.SysConfig.EmailWhiteList {
for _, suffix := range h.App.SysConfig.Base.EmailWhiteList {
if strings.HasSuffix(data.Receiver, suffix) {
inWhiteList = true
break
@@ -98,7 +98,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
}
err = h.smtp.SendVerifyCode(data.Receiver, code)
} else {
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "mobile") {
resp.ERROR(c, "系统已禁用手机号注册!")
return
}

View File

@@ -82,7 +82,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
return
}
if user.Power < h.App.SysConfig.SunoPower {
if user.Power < h.App.SysConfig.Base.SunoPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
@@ -130,7 +130,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
RefSongId: data.RefSongId,
RefTaskId: data.RefTaskId,
ExtendSecs: data.ExtendSecs,
Power: h.App.SysConfig.SunoPower,
Power: h.App.SysConfig.Base.SunoPower,
SongId: utils.RandString(32),
}
if data.Lyrics != "" {

View File

@@ -4,19 +4,20 @@ import (
"geekai/core"
"geekai/service"
"geekai/service/payment"
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"net/http"
)
type TestHandler struct {
App *core.AppServer
db *gorm.DB
snowflake *service.Snowflake
js *payment.GeekPayService
js *payment.EPayService
}
func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.EPayService) *TestHandler {
return &TestHandler{App: app, db: db, snowflake: snowflake, js: js}
}

View File

@@ -20,8 +20,6 @@ import (
"strings"
"time"
"github.com/imroc/req/v3"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
@@ -98,7 +96,7 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
if h.App.SysConfig.Base.EnabledVerify && data.RegWay == "username" {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
@@ -163,7 +161,7 @@ func (h *UserHandler) Register(c *gin.Context) {
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatConfig: "{}",
ChatModels: "{}",
Power: h.App.SysConfig.InitPower,
Power: h.App.SysConfig.Base.InitPower,
}
// check if the username is existing
@@ -188,13 +186,13 @@ func (h *UserHandler) Register(c *gin.Context) {
// 被邀请人也获得赠送算力
if data.InviteCode != "" {
user.Power += h.App.SysConfig.InvitePower
user.Power += h.App.SysConfig.Base.InvitePower
}
if h.licenseService.GetLicense().Configs.DeCopy {
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
} else {
defaultNickname := h.App.SysConfig.DefaultNickname
defaultNickname := h.App.SysConfig.Base.DefaultNickname
if defaultNickname == "" {
defaultNickname = "极客学长"
}
@@ -211,11 +209,11 @@ func (h *UserHandler) Register(c *gin.Context) {
if data.InviteCode != "" {
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.InvitePower, model.PowerLog{
if h.App.SysConfig.Base.InvitePower > 0 {
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.Base.InvitePower, model.PowerLog{
Type: types.PowerInvite,
Model: "Invite",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.Base.InvitePower, inviteCode.Code, user.Username),
})
if err != nil {
tx.Rollback()
@@ -230,7 +228,7 @@ func (h *UserHandler) Register(c *gin.Context) {
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.Base.InvitePower),
}).Error
if err != nil {
tx.Rollback()
@@ -276,7 +274,7 @@ func (h *UserHandler) Login(c *gin.Context) {
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
needVerify, err := h.redis.Get(c, verifyKey).Bool()
if h.App.SysConfig.EnabledVerify && needVerify {
if h.App.SysConfig.Base.EnabledVerify && needVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
@@ -353,150 +351,128 @@ func (h *UserHandler) Logout(c *gin.Context) {
// CLogin 第三方登录请求二维码
func (h *UserHandler) CLogin(c *gin.Context) {
returnURL := h.GetTrim(c, "return_url")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
resp.SUCCESS(c, res.Data)
}
// CLoginCallback 第三方登录回调
func (h *UserHandler) CLoginCallback(c *gin.Context) {
loginType := c.Query("login_type")
code := c.Query("code")
userId := h.GetInt(c, "user_id", 0)
action := c.Query("action")
// loginType := c.Query("login_type")
// code := c.Query("code")
// userId := h.GetInt(c, "user_id", 0)
// action := c.Query("action")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
// var res types.BizVo
// apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
// r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
// SetHeader("AppId", h.App.Config.ApiConfig.AppId).
// SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
// SetSuccessResult(&res).
// Post(apiURL)
// if err != nil {
// resp.ERROR(c, err.Error())
// return
// }
// if r.IsErrorState() {
// resp.ERROR(c, "error with login http status: "+r.Status)
// return
// }
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
// if res.Code != types.Success {
// resp.ERROR(c, "error with http response: "+res.Message)
// return
// }
// login successfully
data := res.Data.(map[string]interface{})
var user model.User
if action == "bind" && userId > 0 {
err = h.DB.Where("openid", data["openid"]).First(&user).Error
if err == nil {
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
return
}
// // login successfully
// data := res.Data.(map[string]interface{})
// var user model.User
// if action == "bind" && userId > 0 {
// err = h.DB.Where("openid", data["openid"]).First(&user).Error
// if err == nil {
// resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
// return
// }
err = h.DB.Where("id", userId).First(&user).Error
if err != nil {
resp.ERROR(c, "绑定用户不存在")
return
}
// err = h.DB.Where("id", userId).First(&user).Error
// if err != nil {
// resp.ERROR(c, "绑定用户不存在")
// return
// }
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
if err != nil {
resp.ERROR(c, "更新用户信息失败,"+err.Error())
return
}
// err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
// if err != nil {
// resp.ERROR(c, "更新用户信息失败,"+err.Error())
// return
// }
resp.SUCCESS(c, gin.H{"token": ""})
return
}
// resp.SUCCESS(c, gin.H{"token": ""})
// return
// }
session := gin.H{}
tx := h.DB.Where("openid", data["openid"]).First(&user)
if tx.Error != nil {
// create new user
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
// session := gin.H{}
// tx := h.DB.Where("openid", data["openid"]).First(&user)
// if tx.Error != nil {
// // create new user
// var totalUser int64
// h.DB.Model(&model.User{}).Count(&totalUser)
// if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
// resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
// return
// }
salt := utils.RandString(8)
password := fmt.Sprintf("%d", utils.RandomNumber(8))
user = model.User{
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
Password: utils.GenPassword(password, salt),
Avatar: fmt.Sprintf("%s", data["avatar"]),
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
Power: h.App.SysConfig.InitPower,
OpenId: fmt.Sprintf("%s", data["openid"]),
Nickname: fmt.Sprintf("%s", data["nickname"]),
}
// salt := utils.RandString(8)
// password := fmt.Sprintf("%d", utils.RandomNumber(8))
// user = model.User{
// Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
// Password: utils.GenPassword(password, salt),
// Avatar: fmt.Sprintf("%s", data["avatar"]),
// Salt: salt,
// Status: true,
// ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
// Power: h.App.SysConfig.InitPower,
// OpenId: fmt.Sprintf("%s", data["openid"]),
// Nickname: fmt.Sprintf("%s", data["nickname"]),
// }
tx = h.DB.Create(&user)
if tx.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(tx.Error)
return
}
session["username"] = user.Username
session["password"] = password
} else { // login directly
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.DB.Model(&user).Updates(user)
// tx = h.DB.Create(&user)
// if tx.Error != nil {
// resp.ERROR(c, "保存数据失败")
// logger.Error(tx.Error)
// return
// }
// session["username"] = user.Username
// session["password"] = password
// } else { // login directly
// // 更新最后登录时间和IP
// user.LastLoginIp = c.ClientIP()
// user.LastLoginAt = time.Now().Unix()
// h.DB.Model(&user).Updates(user)
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
}
// h.DB.Create(&model.UserLoginLog{
// UserId: user.Id,
// Username: user.Username,
// LoginIp: c.ClientIP(),
// LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
// })
// }
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
}
// 保存到 redis
key := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
session["token"] = tokenString
resp.SUCCESS(c, session)
// // 创建 token
// token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
// "user_id": user.Id,
// "expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
// })
// tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
// if err != nil {
// resp.ERROR(c, "Failed to generate token, "+err.Error())
// return
// }
// // 保存到 redis
// key := fmt.Sprintf("users/%d", user.Id)
// if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
// resp.ERROR(c, "error with save token: "+err.Error())
// return
// }
// session["token"] = tokenString
// resp.SUCCESS(c, session)
}
// Session 获取/验证会话
@@ -760,11 +736,11 @@ func (h *UserHandler) SignIn(c *gin.Context) {
// 签到
h.levelDB.Put(key, true)
if h.App.SysConfig.DailyPower > 0 {
h.userService.IncreasePower(userId, h.App.SysConfig.DailyPower, model.PowerLog{
if h.App.SysConfig.Base.DailyPower > 0 {
h.userService.IncreasePower(userId, h.App.SysConfig.Base.DailyPower, model.PowerLog{
Type: types.PowerSignIn,
Model: "SignIn",
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.DailyPower),
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.Base.DailyPower),
})
}
resp.SUCCESS(c)

View File

@@ -78,7 +78,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
return
}
if user.Power < h.App.SysConfig.LumaPower {
if user.Power < h.App.SysConfig.Base.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
@@ -95,14 +95,14 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
}
// 插入数据库
job := model.VideoJob{
UserId: uint(userId),
Type: types.VideoLuma,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
Power: h.App.SysConfig.Base.LumaPower,
TaskInfo: utils.JsonEncode(task),
}
tx := h.DB.Create(&job)
@@ -157,7 +157,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
// 计算当前任务所需算力
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
power := h.App.SysConfig.KeLingPowers[key]
power := h.App.SysConfig.Base.KeLingPowers[key]
if power == 0 {
resp.ERROR(c, "当前模型暂不支持")
return
@@ -191,7 +191,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
Type: types.VideoKeLing,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
Channel: data.Channel,
}
// 插入数据库

View File

@@ -30,7 +30,7 @@ import (
"log"
"os"
"os/signal"
"strconv"
"runtime/debug"
"syscall"
"time"
@@ -71,15 +71,16 @@ func main() {
if configFile == "" {
configFile = "config.toml"
}
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
logger.Info("Loading config file: ", configFile)
if !debug {
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
// 打印堆栈信息
if os.Getenv("GEEKAI_DEBUG") == "true" {
debug.PrintStack()
}
}()
}
}
}()
app := fx.New(
// 初始化配置应用配置
@@ -89,16 +90,16 @@ func main() {
log.Fatal(err)
}
config.Path = configFile
if debug {
_ = core.SaveConfig(config)
}
return config
}),
// 创建应用服务
fx.Provide(core.NewServer),
// 初始化
fx.Invoke(func(s *core.AppServer, client *redis.Client) {
s.Init(debug, client)
s.Init(client)
}),
fx.Provide(func(db *gorm.DB) *types.SystemConfig {
return core.LoadSystemConfig(db)
}),
// 初始化数据库
@@ -111,6 +112,12 @@ func main() {
return xdbFS
}),
// 数据修复
fx.Provide(service.NewDataFixService),
fx.Invoke(func(s *core.AppServer, dfs *service.DataFixService) {
dfs.FixData()
}),
// 创建 Ip2Region 查询对象
fx.Provide(func() (*xdb.Searcher, error) {
file, err := xdbFS.Open("res/ip2region.xdb")
@@ -215,20 +222,34 @@ func main() {
fx.Invoke(func(service *jimeng.Service) {
service.Start()
}),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewJPayService),
fx.Provide(payment.NewWechatService),
fx.Provide(service.NewSnowflake),
// 创建服务
fx.Provide(sms.NewSendServiceManager),
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
return service.NewCaptchaService(config.ApiConfig)
// 创建短信服务
fx.Provide(sms.NewAliYunSmsService),
fx.Provide(sms.NewBaoSmsService),
fx.Provide(sms.NewSmsManager),
fx.Provide(func(config *types.SystemConfig) *service.CaptchaService {
return service.NewCaptchaService(config.GeekAPI.Captcha)
}),
fx.Provide(func(config *types.SystemConfig, client *redis.Client) *service.WxLoginService {
return service.NewWxLoginService(config.GeekAPI.WxLogin, client)
}),
// 支付服务
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewEPayService),
fx.Provide(payment.NewWxpayService),
// 文件上传服务
fx.Provide(oss.NewLocalStorage),
fx.Provide(oss.NewMiniOss),
fx.Provide(oss.NewQiNiuOss),
fx.Provide(oss.NewAliYunOss),
fx.Provide(oss.NewUploaderManager),
// 用户服务
fx.Provide(service.NewUserService),
// 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
h.RegisterRoutes()

View File

@@ -20,9 +20,9 @@ type CaptchaService struct {
client *req.Client
}
func NewCaptchaService(config types.CaptchaConfig) *CaptchaService {
func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
return &CaptchaService{
config: config,
config: captchaConfig,
client: req.C().SetTimeout(10 * time.Second),
}
}

View File

@@ -65,12 +65,6 @@ func (s *ConfigMigrationService) MigrateFromConfig(config *types.AppConfig) erro
return err
}
// 迁移API配置
if err := s.migrateApiConfig(config); err != nil {
logger.Errorf("迁移API配置失败: %v", err)
return err
}
// 标记迁移完成
if err := s.markMigrationCompleted(); err != nil {
logger.Errorf("标记迁移完成失败: %v", err)
@@ -101,58 +95,13 @@ func (s *ConfigMigrationService) markMigrationCompleted() error {
// 迁移支付配置
func (s *ConfigMigrationService) migratePaymentConfig(config *types.AppConfig) error {
// 支付宝配置
alipayConfig := map[string]any{
"enabled": config.AlipayConfig.Enabled,
"sand_box": config.AlipayConfig.SandBox,
"app_id": config.AlipayConfig.AppId,
"private_key": config.AlipayConfig.PrivateKey,
"alipay_public_key": config.AlipayConfig.AlipayPublicKey,
"notify_url": config.AlipayConfig.NotifyURL,
"return_url": config.AlipayConfig.ReturnURL,
}
if err := s.saveConfig("alipay", alipayConfig); err != nil {
return err
}
// 微信支付配置
wechatConfig := map[string]any{
"enabled": config.WechatPayConfig.Enabled,
"app_id": config.WechatPayConfig.AppId,
"mch_id": config.WechatPayConfig.MchId,
"serial_no": config.WechatPayConfig.SerialNo,
"private_key": config.WechatPayConfig.PrivateKey,
"api_v3_key": config.WechatPayConfig.ApiV3Key,
"notify_url": config.WechatPayConfig.NotifyURL,
paymentConfig := types.PaymentConfig{
Alipay: config.AlipayConfig,
Epay: config.GeekPayConfig,
WxPay: config.WechatPayConfig,
}
if err := s.saveConfig("wechat", wechatConfig); err != nil {
return err
}
// 虎皮椒配置
hupiConfig := map[string]any{
"enabled": config.HuPiPayConfig.Enabled,
"app_id": config.HuPiPayConfig.AppId,
"app_secret": config.HuPiPayConfig.AppSecret,
"api_url": config.HuPiPayConfig.ApiURL,
"notify_url": config.HuPiPayConfig.NotifyURL,
"return_url": config.HuPiPayConfig.ReturnURL,
}
if err := s.saveConfig("hupi", hupiConfig); err != nil {
return err
}
// GeekPay配置
geekpayConfig := map[string]any{
"enabled": config.GeekPayConfig.Enabled,
"app_id": config.GeekPayConfig.AppId,
"private_key": config.GeekPayConfig.PrivateKey,
"api_url": config.GeekPayConfig.ApiURL,
"notify_url": config.GeekPayConfig.NotifyURL,
"return_url": config.GeekPayConfig.ReturnURL,
"methods": config.GeekPayConfig.Methods,
}
if err := s.saveConfig("geekpay", geekpayConfig); err != nil {
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
return err
}
@@ -161,37 +110,15 @@ func (s *ConfigMigrationService) migratePaymentConfig(config *types.AppConfig) e
// 迁移存储配置
func (s *ConfigMigrationService) migrateStorageConfig(config *types.AppConfig) error {
ossConfig := map[string]any{
"active": config.OSS.Active,
"local": map[string]any{
"base_path": config.OSS.Local.BasePath,
"base_url": config.OSS.Local.BaseURL,
},
"minio": map[string]any{
"endpoint": config.OSS.Minio.Endpoint,
"access_key": config.OSS.Minio.AccessKey,
"access_secret": config.OSS.Minio.AccessSecret,
"bucket": config.OSS.Minio.Bucket,
"use_ssl": config.OSS.Minio.UseSSL,
"domain": config.OSS.Minio.Domain,
},
"qiniu": map[string]any{
"zone": config.OSS.QiNiu.Zone,
"access_key": config.OSS.QiNiu.AccessKey,
"access_secret": config.OSS.QiNiu.AccessSecret,
"bucket": config.OSS.QiNiu.Bucket,
"domain": config.OSS.QiNiu.Domain,
},
"aliyun": map[string]any{
"endpoint": config.OSS.AliYun.Endpoint,
"access_key": config.OSS.AliYun.AccessKey,
"access_secret": config.OSS.AliYun.AccessSecret,
"bucket": config.OSS.AliYun.Bucket,
"sub_dir": config.OSS.AliYun.SubDir,
"domain": config.OSS.AliYun.Domain,
},
ossConfig := types.OSSConfig{
Active: config.OSS.Active,
Local: config.OSS.Local,
Minio: config.OSS.Minio,
QiNiu: config.OSS.QiNiu,
AliYun: config.OSS.AliYun,
}
return s.saveConfig("oss", ossConfig)
return s.saveConfig(types.ConfigKeyOss, ossConfig)
}
// 迁移通信配置
@@ -205,7 +132,7 @@ func (s *ConfigMigrationService) migrateCommunicationConfig(config *types.AppCon
"from": config.SmtpConfig.From,
"password": config.SmtpConfig.Password,
}
if err := s.saveConfig("smtp", smtpConfig); err != nil {
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
return err
}
@@ -215,34 +142,17 @@ func (s *ConfigMigrationService) migrateCommunicationConfig(config *types.AppCon
"ali": map[string]any{
"access_key": config.SMS.Ali.AccessKey,
"access_secret": config.SMS.Ali.AccessSecret,
"product": config.SMS.Ali.Product,
"domain": config.SMS.Ali.Domain,
"sign": config.SMS.Ali.Sign,
"code_temp_id": config.SMS.Ali.CodeTempId,
},
"bao": map[string]any{
"username": config.SMS.Bao.Username,
"password": config.SMS.Bao.Password,
"domain": config.SMS.Bao.Domain,
"sign": config.SMS.Bao.Sign,
"code_template": config.SMS.Bao.CodeTemplate,
},
}
return s.saveConfig("sms", smsConfig)
}
// 迁移API配置
func (s *ConfigMigrationService) migrateApiConfig(config *types.AppConfig) error {
apiConfig := map[string]any{
"api_url": config.ApiConfig.ApiURL,
"app_id": config.ApiConfig.AppId,
"token": config.ApiConfig.Token,
"jimeng_config": map[string]any{
"access_key": config.ApiConfig.JimengConfig.AccessKey,
"secret_key": config.ApiConfig.JimengConfig.SecretKey,
},
}
return s.saveConfig("api", apiConfig)
return s.saveConfig(types.ConfigKeySms, smsConfig)
}
// 保存配置到数据库

View File

@@ -0,0 +1,66 @@
package service
import (
"geekai/store/model"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type DataFixService struct {
db *gorm.DB
redis *redis.Client
}
func NewDataFixService(db *gorm.DB, redis *redis.Client) *DataFixService {
return &DataFixService{db: db, redis: redis}
}
func (s *DataFixService) FixData() {
s.FixColumn()
}
// 字段修正
func (s *DataFixService) FixColumn() {
// 订单字段整理
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
}
if !s.db.Migrator().HasColumn(&model.Order{}, "check") {
s.db.Migrator().AddColumn(&model.Order{}, "checked")
}
// 重命名 config 表字段
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
}
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
}
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
}
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
s.db.Migrator().DropIndex(&model.Config{}, "marker")
}
// 手动删除字段
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
}
}

View File

@@ -21,7 +21,6 @@ import (
)
type LicenseService struct {
config types.GeekServiceConfig
levelDB *store.LevelDB
license *types.License
urlWhiteList []string
@@ -39,7 +38,6 @@ func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseS
}
logger.Infof("License: %+v", license)
return &LicenseService{
config: server.Config.ApiConfig,
levelDB: levelDB,
license: &license,
machineId: machineId,
@@ -63,7 +61,7 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
response, err := req.C().R().
SetBody(map[string]string{"license": license, "machine_id": machineId}).
SetSuccessResult(&res).Post(apiURL)
@@ -129,7 +127,7 @@ func (s *LicenseService) fetchLicense() (*types.License, error) {
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
response, err := req.C().R().
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL)
@@ -158,7 +156,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
Message string `json:"message"`
Data []string `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err)

View File

@@ -28,30 +28,32 @@ type AliYunOss struct {
proxyURL string
}
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
config := &appConfig.OSS.AliYun
// 创建 OSS 客户端
func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
s := &AliYunOss{
proxyURL: appConfig.ProxyURL,
}
if sysConfig.OSS.Active == AliYun {
err := s.UpdateConfig(&sysConfig.OSS.AliYun)
if err != nil {
logger.Errorf("阿里云OSS初始化失败: %v", err)
}
}
return s, nil
}
func (s *AliYunOss) UpdateConfig(config *types.AliYunOssConfig) error {
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
if err != nil {
return nil, err
return err
}
// 获取存储空间
bucket, err := client.Bucket(config.Bucket)
if err != nil {
return nil, err
return err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return &AliYunOss{
config: config,
bucket: bucket,
proxyURL: appConfig.ProxyURL,
}, nil
s.bucket = bucket
s.config = config
return nil
}
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {

View File

@@ -25,13 +25,17 @@ type LocalStorage struct {
proxyURL string
}
func NewLocalStorage(config *types.AppConfig) LocalStorage {
return LocalStorage{
config: &config.OSS.Local,
proxyURL: config.ProxyURL,
func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
return &LocalStorage{
config: &sysConfig.OSS.Local,
proxyURL: appConfig.ProxyURL,
}
}
func (s *LocalStorage) UpdateConfig(config *types.LocalStorageConfig) {
s.config = config
}
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name)
if err != nil {

View File

@@ -29,19 +29,29 @@ type MiniOss struct {
proxyURL string
}
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
config := &appConfig.OSS.Minio
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
s := &MiniOss{proxyURL: appConfig.ProxyURL}
if sysConfig.OSS.Active == Minio {
err := s.UpdateConfig(&sysConfig.OSS.Minio)
if err != nil {
logger.Errorf("MinioOSS初始化失败: %v", err)
}
}
return s, nil
}
func (s *MiniOss) UpdateConfig(config *types.MiniOssConfig) error {
minioClient, err := minio.New(config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
Secure: config.UseSSL,
})
if err != nil {
return MiniOss{}, err
return err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
s.config = config
s.client = minioClient
return nil
}
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {

View File

@@ -29,13 +29,21 @@ type QinNiuOss struct {
mac *qbox.Mac
putPolicy storage.PutPolicy
uploader *storage.FormUploader
manager *storage.BucketManager
bucket *storage.BucketManager
proxyURL string
}
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
config := &appConfig.OSS.QiNiu
// build storage uploader
func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QinNiuOss {
s := &QinNiuOss{
proxyURL: appConfig.ProxyURL,
}
if sysConfig.OSS.Active == QiNiu {
s.UpdateConfig(&sysConfig.OSS.QiNiu)
}
return s
}
func (s *QinNiuOss) UpdateConfig(config *types.QiNiuOssConfig) {
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
if !ok {
zone = storage.ZoneHuanan
@@ -47,19 +55,12 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
putPolicy := storage.PutPolicy{
Scope: config.Bucket,
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return QinNiuOss{
config: config,
mac: mac,
putPolicy: putPolicy,
uploader: formUploader,
manager: storage.NewBucketManager(mac, &storeConfig),
proxyURL: appConfig.ProxyURL,
}
s.config = config
s.mac = mac
s.putPolicy = putPolicy
s.uploader = formUploader
s.bucket = storage.NewBucketManager(mac, &storeConfig)
}
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单
file, err := ctx.FormFile(name)
@@ -147,7 +148,7 @@ func (s QinNiuOss) Delete(fileURL string) error {
objectKey = fileURL
}
return s.manager.Delete(s.config.Bucket, objectKey)
return s.bucket.Delete(s.config.Bucket, objectKey)
}
var _ Uploader = QinNiuOss{}

View File

@@ -10,44 +10,45 @@ package oss
import (
"geekai/core/types"
"strings"
logger2 "geekai/logger"
)
var logger = logger2.GetLogger()
type UploaderManager struct {
handler Uploader
local *LocalStorage
aliyun *AliYunOss
mini *MiniOss
qiniu *QinNiuOss
config *types.OSSConfig
}
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
active := Local
if config.OSS.Active != "" {
active = strings.ToUpper(config.OSS.Active)
}
var handler Uploader
switch active {
case Local:
handler = NewLocalStorage(config)
break
case Minio:
client, err := NewMiniOss(config)
if err != nil {
return nil, err
}
handler = client
break
case QiNiu:
handler = NewQiNiuOss(config)
break
case AliYun:
client, err := NewAliYunOss(config)
if err != nil {
return nil, err
}
handler = client
break
func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QinNiuOss) (*UploaderManager, error) {
if sysConfig.OSS.Active == "" {
sysConfig.OSS.Active = Local
}
sysConfig.OSS.Active = strings.ToLower(sysConfig.OSS.Active)
return &UploaderManager{handler: handler}, nil
return &UploaderManager{
config: &sysConfig.OSS,
local: local,
aliyun: aliyun,
mini: mini,
qiniu: qiniu,
}, nil
}
func (m *UploaderManager) GetUploadHandler() Uploader {
return m.handler
switch m.config.Active {
case Local:
return m.local
case AliYun:
return m.aliyun
case Minio:
return m.mini
case QiNiu:
return m.qiniu
}
return m.local
}

View File

@@ -20,109 +20,90 @@ import (
)
type AlipayService struct {
config *types.AlipayConfig
client *alipay.Client
config *types.AlipayConfig
}
var logger = logger2.GetLogger()
func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
config := appConfig.AlipayConfig
func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) {
config := sysConfig.Payment.Alipay
if !config.Enabled {
logger.Info("Disabled Alipay service")
return nil, nil
logger.Debug("Disabled Alipay service")
}
service := &AlipayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("支付宝服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error {
client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox)
if err != nil {
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
return fmt.Errorf("error with initialize alipay service: %v", err)
}
return &AlipayService{config: &config, client: client}, nil
s.client = client
s.config = config
if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("Alipay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
}
return nil
}
type AlipayParams struct {
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"subject"`
TotalFee string `json:"total_fee"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("quit_url", params.ReturnURL)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
func (s *AlipayService) Pay(params PayRequest) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
return s.client.TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) {
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with trade query: %v", err)
}
switch rsp.Response.TradeStatus {
case "TRADE_SUCCESS":
logger.Debugf("支付宝查询订单成功:%+v", rsp.Response)
return OrderInfo{
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Status: Success,
PayTime: rsp.Response.SendPayDate,
}, nil
case "TRADE_CLOSED":
return OrderInfo{Status: Closed}, nil
default:
return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus)
}
}
// TradeVerify 交易验证
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with parse notify request: " + err.Error(),
}
return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err)
}
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with verify sign: " + err.Error(),
}
return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err)
}
return s.TradeQuery(request.Form.Get("out_trade_no"))
return s.Query(request.Form.Get("out_trade_no"))
}
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
//查询订单
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
}
}
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
return NotifyVo{
Status: Success,
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Subject: rsp.Response.Subject,
Message: "OK",
}
} else {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo,
}
}
}
func readKey(filename string) (string, error) {
data, err := os.ReadFile(filename)
if err != nil {
return "", err
}
return string(data), nil
}
var _ PayService = (*AlipayService)(nil)

View File

@@ -22,41 +22,30 @@ import (
"time"
)
// GeekPayService Geek 支付服务
type GeekPayService struct {
config *types.GeekPayConfig
// EPayService 支付服务
type EPayService struct {
config *types.EpayConfig
}
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
return &GeekPayService{
config: &appConfig.GeekPayConfig,
func NewEPayService(sysConfig *types.SystemConfig) *EPayService {
return &EPayService{
config: &sysConfig.Payment.Epay,
}
}
type GeekPayParams struct {
Method string `json:"method"` // 接口类型
Device string `json:"device"` // 设备类型
Type string `json:"type"` // 支付方式
OutTradeNo string `json:"out_trade_no"` // 商户订单号
Name string `json:"name"` // 商品名称
Money string `json:"money"` // 商品金额
ClientIP string `json:"clientip"` //用户IP地址
SubOpenId string `json:"sub_openid"` // 微信用户 openid仅小程序支付需要
SubAppId string `json:"sub_appid"` // 小程序 AppId仅小程序支付需要
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
func (s *EPayService) UpdateConfig(config *types.EpayConfig) {
s.config = config
}
// Pay 支付订单
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
func (s *EPayService) Pay(params PayRequest) (string, error) {
p := map[string]string{
"pid": s.config.AppId,
//"method": params.Method,
"pid": s.config.AppId,
"device": params.Device,
"type": params.Type,
"type": params.PayWay,
"out_trade_no": params.OutTradeNo,
"name": params.Name,
"money": params.Money,
"name": params.Subject,
"money": params.TotalFee,
"clientip": params.ClientIP,
"notify_url": params.NotifyURL,
"return_url": params.ReturnURL,
@@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
}
p["sign"] = s.Sign(p)
p["sign_type"] = "MD5"
return s.sendRequest(s.config.ApiURL, p)
resp, err := s.sendRequest(s.config.ApiURL, p)
if err != nil {
return "", err
}
if resp.Code != 1 {
return "", errors.New(resp.Msg)
}
if resp.PayURL != "" {
return resp.PayURL, nil
} else {
return resp.QrCode, nil
}
}
func (s *GeekPayService) Sign(params map[string]string) string {
func (s *EPayService) Sign(params map[string]string) string {
// 按字母顺序排序参数
var keys []string
for k := range params {
@@ -100,7 +100,7 @@ type GeekPayResp struct {
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
}
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
form := url.Values{}
for k, v := range params {
form.Add(k, v)
@@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string)
}
return &r, nil
}
func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) {
params := url.Values{}
params.Set("act", "order")
params.Set("pid", s.config.AppId)
params.Set("key", s.config.PrivateKey)
params.Set("out_trade_no", outTradeNo)
apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode())
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: tr}
resp, err := client.Get(apiURL)
if err != nil {
return OrderInfo{}, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return OrderInfo{}, err
}
logger.Debugf(string(body))
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Status string `json:"status"`
Name string `json:"name"`
Money string `json:"money"`
EndTime string `json:"endtime"`
TradeNo string `json:"trade_no"`
}
if err := json.Unmarshal(body, &result); err != nil {
return OrderInfo{}, errors.New("订单查询响应解析失败")
}
if result.Code != 1 {
return OrderInfo{}, errors.New(result.Msg)
}
logger.Debugf("订单信息:%+v", result)
orderInfo := OrderInfo{
OutTradeNo: outTradeNo,
TradeId: result.TradeNo,
Amount: result.Money,
PayTime: result.EndTime,
}
if result.Status == "1" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
var _ PayService = (*EPayService)(nil)

View File

@@ -1,171 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
type HuPiPayService struct {
appId string
appSecret string
apiURL string
}
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
return &HuPiPayService{
appId: config.HuPiPayConfig.AppId,
appSecret: config.HuPiPayConfig.AppSecret,
apiURL: config.HuPiPayConfig.ApiURL,
}
}
type HuPiPayParams struct {
AppId string `json:"appid"`
Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"`
TotalFee string `json:"total_fee"`
Title string `json:"title"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
WapName string `json:"wap_name"`
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
Type string `json:"type"`
WapUrl string `json:"wap_url"`
}
type HuPiPayResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId
params.Time = simple
params.NonceStr = simple
encode := utils.JsonEncode(params)
m := make(map[string]string)
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
data.Add(k, fmt.Sprintf("%v", v))
}
// 生成签名
data.Add("hash", s.Sign(data))
// 发送支付请求
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
}
var res HuPiPayResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
}
// Sign 签名方法
func (s *HuPiPayService) Sign(params url.Values) string {
params.Del(`Sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += s.appSecret
md5bs := md5.Sum([]byte(src))
return hex.EncodeToString(md5bs[:])
}
// Check 校验订单状态
func (s *HuPiPayService) Check(outTradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("out_trade_order", outTradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)
data.Add("hash", s.Sign(data))
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ErrCode int `json:"errcode"`
Data struct {
Status string `json:"status"`
OpenOrderId string `json:"open_order_id"`
} `json:"data,omitempty"`
ErrMsg string `json:"errmsg"`
Hash string `json:"hash"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ErrCode == 0 && r.Data.Status == "OD" {
return nil
} else {
logger.Debugf("%+v", r)
return errors.New("order not paid" + r.ErrMsg)
}
}

View File

@@ -0,0 +1,54 @@
package payment
// 支付渠道定义
const PayChannelAL = "alipay" // 支付宝
const PayChannelWX = "wxpay" // 微信支付
const PayChannelEpay = "epay" // 易支付
// 支付方式
const PayWayAL = "alipay"
const PayWayWX = "wxpay"
const (
Success = 0
Failure = 1
Closed = 2
)
type PayRequest struct {
OutTradeNo string // 商户订单号
Subject string // 商品名称
TotalFee string // 商品金额
ReturnURL string // 回调地址
NotifyURL string // 回调地址
// 易支付专有参数
Method string // 接口类型
Device string // 设备类型
PayWay string // 支付方式
ClientIP string //用户IP地址
OpenID string // 用户openid
}
type OrderInfo struct {
Mchid string // 商户号
OutTradeNo string // 商户订单号
TradeId string // 交易号
Amount string // 金额
Status int // 状态 0: 未支付 1: 已支付 2: 已关闭
PayTime string // 完成支付时间
}
func (o OrderInfo) Closed() bool {
return o.Status == Closed
}
func (o OrderInfo) Success() bool {
return o.Status == Success
}
type PayService interface {
Pay(params PayRequest) (string, error) // 生成支付链接
Query(outTradeNo string) (OrderInfo, error) // 查询订单
}

View File

@@ -1,19 +0,0 @@
package payment
type NotifyVo struct {
Status int
OutTradeNo string // 商户订单号
TradeId string // 交易ID
Amount string // 交易金额
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == Success
}
const (
Success = 0
Failure = 1
)

View File

@@ -1,141 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"net/http"
"time"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
)
type WechatPayService struct {
config *types.WechatPayConfig
client *wechat.ClientV3
}
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
config := appConfig.WechatPayConfig
if !config.Enabled {
logger.Info("Disabled WechatPay service")
return nil, nil
}
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
if err != nil {
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
}
//client.DebugSwitch = gopay.DebugOn
return &WechatPayService{config: &config, client: client}, nil
}
type WechatPayParams struct {
OutTradeNo string `json:"out_trade_no"`
TotalFee int `json:"total_fee"`
Subject string `json:"subject"`
ClientIP string `json:"client_ip"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
})
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
}).
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP).
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap")
})
})
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.H5Url, nil
}
type NotifyResponse struct {
Code string `json:"code"`
Message string `xml:"message"`
}
// TradeVerify 交易验证
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
}
// TODO: 这里验签程序有 Bug一直报错crypto/rsa: verification error先暂时取消验签
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
//if err != nil {
// return fmt.Errorf("error with client v3 verify sign: %v", err)
//}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
}
return NotifyVo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
}
}

View File

@@ -0,0 +1,216 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/http"
"os"
"time"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
)
type WxPayService struct {
config *types.WxPayConfig
client *wechat.ClientV3
}
func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) {
config := sysConfig.Payment.WxPay
if !config.Enabled {
logger.Debug("Disabled WechatPay service")
}
service := &WxPayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("微信支付服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error {
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
if err != nil {
return fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return fmt.Errorf("error with autoVerifySign: %v", err)
}
s.client = client
if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("WechatPay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
}
s.config = config
return nil
}
func (s *WxPayService) Pay(params PayRequest) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", utils.IntValue(params.TotalFee, 0)).
Set("currency", "CNY")
})
if params.Device == "mobile" {
bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP)
}).SetBodyMap("payer", func(bm gopay.BodyMap) {
bm.Set("openid", params.OpenID)
})
wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.PrepayId, nil
} else if params.Device == "pc" {
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
return "", nil
}
func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) {
wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err)
}
if wxRsp.Code != wechat.Success {
return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error)
}
if wxRsp.Response.TradeState == "CLOSED" {
return OrderInfo{Status: Closed}, nil
}
orderInfo := OrderInfo{
OutTradeNo: wxRsp.Response.OutTradeNo,
TradeId: wxRsp.Response.TransactionId,
Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100),
PayTime: wxRsp.Response.SuccessTime,
}
if wxRsp.Response.TradeState == "SUCCESS" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
// TradeVerify 交易验证
func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err)
}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err)
}
return OrderInfo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
PayTime: result.SuccessTime,
}, nil
}
// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// })
// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.CodeUrl, nil
// }
// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// }).
// SetBodyMap("scene_info", func(bm gopay.BodyMap) {
// bm.Set("payer_client_ip", params.ClientIP).
// SetBodyMap("h5_info", func(bm gopay.BodyMap) {
// bm.Set("type", "Wap")
// })
// })
// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.H5Url, nil
// }
// type NotifyResponse struct {
// Code string `json:"code"`
// Message string `xml:"message"`
// }
var _ PayService = (*WxPayService)(nil)

View File

@@ -10,36 +10,57 @@ package sms
import (
"fmt"
"geekai/core/types"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
)
type AliYunSmsService struct {
config *types.SmsConfigAli
client *dysmsapi.Client
domain string
zoneId string
}
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
config := &appConfig.SMS.Ali
// 创建阿里云短信客户端
func NewAliYunSmsService(sysConfig *types.SystemConfig) (*AliYunSmsService, error) {
config := &sysConfig.SMS.Ali
domain := "dysmsapi.aliyuncs.com"
zoneId := "cn-hangzhou"
s := AliYunSmsService{
config: config,
domain: domain,
zoneId: zoneId,
}
if sysConfig.SMS.Active == Ali {
err := s.UpdateConfig(config)
if err != nil {
logger.Errorf("阿里云短信初始化失败: %v", err)
}
}
return &s, nil
}
func (s *AliYunSmsService) UpdateConfig(config *types.SmsConfigAli) error {
client, err := dysmsapi.NewClientWithAccessKey(
"cn-hangzhou",
s.zoneId,
config.AccessKey,
config.AccessSecret)
if err != nil {
return nil, fmt.Errorf("failed to create client: %v", err)
return fmt.Errorf("failed to create client: %v", err)
}
return &AliYunSmsService{
config: config,
client: client,
}, nil
s.client = client
s.config = config
return nil
}
func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
if s.client == nil {
return fmt.Errorf("阿里云短信服务未初始化")
}
// 创建短信请求并设置参数
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.Domain = s.config.Domain
request.Domain = s.domain
request.PhoneNumbers = mobile
request.SignName = s.config.Sign
request.TemplateCode = s.config.CodeTempId

View File

@@ -20,19 +20,20 @@ import (
type BaoSmsService struct {
config *types.SmsConfigBao
domain string
}
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
config := appConfig.SMS.Bao
if config.Domain == "" { // use default domain
config.Domain = "api.smsbao.com"
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
}
func NewBaoSmsService(sysConfig *types.SystemConfig) *BaoSmsService {
return &BaoSmsService{
config: &config,
config: &sysConfig.SMS.Bao,
domain: "api.smsbao.com",
}
}
func (s *BaoSmsService) UpdateConfig(config *types.SmsConfigBao) {
s.config = config
}
var errMsg = map[string]string{
"0": "短信发送成功",
"-1": "参数不全",
@@ -56,7 +57,7 @@ func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
params.Set("m", mobile)
params.Set("c", content)
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
apiURL := fmt.Sprintf("https://%s/sms?%s", s.domain, params.Encode())
response, err := http.Get(apiURL)
if err != nil {
return err

View File

@@ -10,37 +10,32 @@ package sms
import (
"geekai/core/types"
logger2 "geekai/logger"
"strings"
)
type ServiceManager struct {
handler Service
type SmsManager struct {
aliyun *AliYunSmsService
bao *BaoSmsService
active string
}
var logger = logger2.GetLogger()
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
active := Ali
if config.SMS.Active != "" {
active = strings.ToUpper(config.SMS.Active)
}
var handler Service
switch active {
case Ali:
client, err := NewAliYunSmsService(config)
if err != nil {
return nil, err
}
handler = client
break
case Bao:
handler = NewSmsBaoSmsService(config)
break
}
func NewSmsManager(sysConfig *types.SystemConfig, aliyun *AliYunSmsService, bao *BaoSmsService) (*SmsManager, error) {
return &ServiceManager{handler: handler}, nil
return &SmsManager{
active: sysConfig.SMS.Active,
aliyun: aliyun,
bao: bao,
}, nil
}
func (m *ServiceManager) GetService() Service {
return m.handler
func (m *SmsManager) GetService() Service {
if m.active == Ali {
return m.aliyun
}
return m.bao
}
func (m *SmsManager) SetActive(active string) {
m.active = active
}

View File

@@ -0,0 +1,109 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"time"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
)
type WxLoginService struct {
config types.WxLoginConfig
client *req.Client
redisClient *redis.Client
}
const loginStateKeyPrefix = "wx_login_state/"
type LoginStatus struct {
Status string `json:"status"`
OpenID string `json:"openid"`
Token string `json:"token"`
}
const (
LoginStatusPending = "pending"
LoginStatusSuccess = "success"
LoginStatusExpired = "expired" // 登录失效,需要重新登录
)
func NewWxLoginService(config types.WxLoginConfig, redisClient *redis.Client) *WxLoginService {
return &WxLoginService{
config: config,
client: req.C().SetTimeout(10 * time.Second),
redisClient: redisClient,
}
}
func (s *WxLoginService) UpdateConfig(config types.WxLoginConfig) {
s.config = config
}
func (s *WxLoginService) GetLoginQrCodeUrl(state string) (string, error) {
if s.config.ApiKey == "" {
return "", errors.New("无效的 API Key")
}
url := fmt.Sprintf("%s/api/auth/wechat/login", types.GeekAPIURL)
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data struct {
Ticket string `json:"ticket"`
Url string `json:"url"`
} `json:"data"`
}
r, err := s.client.R().
SetHeader("Authorization", s.config.ApiKey).
SetBody(map[string]string{
"notify_url": s.config.NotifyURL,
"state": state,
}).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("请求 API 失败:%v", err)
}
if res.Code != types.Success {
return "", fmt.Errorf("请求 API 失败:%s", res.Message)
}
status := LoginStatus{
Status: LoginStatusPending,
OpenID: "",
}
s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour)
return res.Data.Url, nil
}
func (s *WxLoginService) GetLoginStatus(state string) (*LoginStatus, error) {
result, err := s.redisClient.Get(context.Background(), loginStateKeyPrefix+state).Result()
if err != nil {
return nil, errors.New("登录失败")
}
var status LoginStatus
err = utils.JsonDecode(result, &status)
if err != nil {
return nil, errors.New("登录失败")
}
return &status, nil
}
func (s *WxLoginService) SetLoginStatus(state string, status LoginStatus) {
s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour)
}

View File

@@ -7,21 +7,22 @@ import (
// Order 充值订单
type Order struct {
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"`
ProductId uint `gorm:"column:product_id;type:int;not null;comment:产品ID" json:"product_id"`
Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"`
OrderNo string `gorm:"column:order_no;type:varchar(30);uniqueIndex;not null;comment:订单ID" json:"order_no"`
TradeNo string `gorm:"column:trade_no;type:varchar(60);comment:支付平台交易流水号" json:"trade_no"`
Subject string `gorm:"column:subject;type:varchar(100);not null;comment:订单产品" json:"subject"`
Amount float64 `gorm:"column:amount;type:decimal(10,2);not null;default:0.00;comment:订单金额" json:"amount"`
Status types.OrderStatus `gorm:"column:status;type:tinyint(1);not null;default:0;comment:订单状态0待支付1已扫码2支付成功" json:"status"`
Remark string `gorm:"column:remark;type:varchar(255);not null;comment:备注" json:"remark"`
PayTime int64 `gorm:"column:pay_time;type:int;comment:支付时间" json:"pay_time"`
PayWay string `gorm:"column:pay_way;type:varchar(20);not null;comment:支付方式" json:"pay_way"`
PayType string `gorm:"column:pay_type;type:varchar(30);not null;comment:支付类型" json:"pay_type"`
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"`
ProductId uint `gorm:"column:product_id;type:int;not null;comment:产品ID" json:"product_id"`
Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"`
OrderNo string `gorm:"column:order_no;type:varchar(30);uniqueIndex;not null;comment:订单ID" json:"order_no"`
TradeNo string `gorm:"column:trade_no;type:varchar(60);comment:支付平台交易流水号" json:"trade_no"`
Subject string `gorm:"column:subject;type:varchar(100);not null;comment:订单产品" json:"subject"`
Amount float64 `gorm:"column:amount;type:decimal(10,2);not null;default:0.00;comment:订单金额" json:"amount"`
Status types.OrderStatus `gorm:"column:status;type:tinyint(1);not null;default:0;comment:订单状态0待支付1已扫码2支付成功" json:"status"`
Remark string `gorm:"column:remark;type:varchar(255);not null;comment:备注" json:"remark"`
PayTime int64 `gorm:"column:pay_time;type:int;comment:支付时间" json:"pay_time"`
PayWay string `gorm:"column:pay_way;type:varchar(20);not null;comment:支付方式" json:"pay_way"`
Channel string `gorm:"column:channel;type:varchar(30);not null;comment:支付类型渠道:支付宝,微信,聚合支付"` // 支付类型渠道
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"`
Checked bool `gorm:"column:checked;type:tinyint;not null;default:0;comment:是否已检查"` // 是否已检查
}
func (m *Order) TableName() string {

View File

@@ -6,18 +6,18 @@ import (
type Order struct {
BaseVo
UserId uint `json:"user_id"`
ProductId uint `json:"product_id"`
Username string `json:"username"`
OrderNo string `json:"order_no"`
TradeNo string `json:"trade_no"`
Subject string `json:"subject"`
Amount float64 `json:"amount"`
Status types.OrderStatus `json:"status"`
PayTime int64 `json:"pay_time"`
PayWay string `json:"pay_way"`
PayType string `json:"pay_type"`
PayMethod string `json:"pay_method"`
PayName string `json:"pay_name"`
Remark types.OrderRemark `json:"remark"`
UserId uint `json:"user_id"`
ProductId uint `json:"product_id"`
Username string `json:"username"`
OrderNo string `json:"order_no"`
TradeNo string `json:"trade_no"`
Subject string `json:"subject"`
Amount float64 `json:"amount"`
Status types.OrderStatus `json:"status"`
PayTime int64 `json:"pay_time"`
PayWay string `json:"pay_way"`
Channel string `json:"channel"`
ChannelName string `json:"channel_name"`
PayName string `json:"pay_name"`
Remark types.OrderRemark `json:"remark"`
}

View File

@@ -15,6 +15,7 @@ import (
"io"
"net/http"
"net/url"
"path/filepath"
)
var logger = logger2.GetLogger()
@@ -92,3 +93,11 @@ func GetBaseURL(strURL string) string {
}
return fmt.Sprintf("%s://%s", u.Scheme, u.Host)
}
func GetImgExt(filename string) string {
ext := filepath.Ext(filename)
if ext == "" {
return ".png"
}
return ext
}