From 536b4b8056b8840a671a63f371353129e3108b5b Mon Sep 17 00:00:00 2001 From: GeekMaster Date: Sun, 24 Aug 2025 19:32:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E4=BB=98=EF=BC=8COSS=20=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E9=87=8D=E6=9E=84=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/app_server.go | 77 +-- api/core/config.go | 77 +++ api/core/midware/auth_midware.go | 107 ++++ api/core/midware/parameter_midware.go | 80 +++ api/core/midware/rate_limit_midware.go | 43 ++ api/core/types/config.go | 54 +- api/core/types/geekai.go | 5 + api/core/types/order.go | 25 +- api/core/types/payment.go | 22 +- api/core/types/session.go | 1 + api/core/types/sms.go | 3 - api/go.mod | 1 + api/go.sum | 2 + api/handler/admin/admin_handler.go | 5 +- api/handler/admin/config_handler.go | 69 +-- api/handler/admin/order_handler.go | 12 +- api/handler/admin/upload_handler.go | 2 +- api/handler/captcha_handler.go | 6 +- api/handler/chat_handler.go | 8 +- api/handler/dalle_handler.go | 2 +- api/handler/function_handler.go | 8 +- api/handler/menu_handler.go | 2 +- api/handler/mj_handler.go | 16 +- api/handler/order_handler.go | 13 +- api/handler/payment_handler.go | 499 +++++++++--------- api/handler/prompt_handler.go | 22 +- api/handler/realtime_handler.go | 4 +- api/handler/sd_handler.go | 6 +- api/handler/sms_handler.go | 14 +- api/handler/suno_handler.go | 4 +- api/handler/test_handler.go | 7 +- api/handler/user_handler.go | 260 +++++---- api/handler/video_handler.go | 10 +- api/main.go | 63 ++- api/service/captcha_service.go | 4 +- api/service/config_migration.go | 120 +---- api/service/data_fix_service.go | 66 +++ api/service/license_service.go | 8 +- api/service/oss/aliyun_oss.go | 38 +- api/service/oss/localstorage.go | 12 +- api/service/oss/minio_oss.go | 24 +- api/service/oss/qiniu_oss.go | 35 +- api/service/oss/uploader_manager.go | 61 +-- api/service/payment/alipay_service.go | 131 ++--- .../{geekpay_service.go => epay_service.go} | 112 +++- api/service/payment/hupipay_serive.go | 171 ------ api/service/payment/pay_service.go | 54 ++ api/service/payment/types.go | 19 - api/service/payment/wepay_service.go | 141 ----- api/service/payment/wxpay_service.go | 216 ++++++++ api/service/sms/aliyun.go | 43 +- api/service/sms/bao.go | 17 +- api/service/sms/service_manager.go | 43 +- api/service/wxlogin_service.go | 109 ++++ api/store/model/order.go | 31 +- api/store/vo/order.go | 28 +- api/utils/net.go | 9 + 57 files changed, 1663 insertions(+), 1358 deletions(-) create mode 100644 api/core/midware/auth_midware.go create mode 100644 api/core/midware/parameter_midware.go create mode 100644 api/core/midware/rate_limit_midware.go create mode 100644 api/service/data_fix_service.go rename api/service/payment/{geekpay_service.go => epay_service.go} (50%) delete mode 100644 api/service/payment/hupipay_serive.go create mode 100644 api/service/payment/pay_service.go delete mode 100644 api/service/payment/types.go delete mode 100644 api/service/payment/wepay_service.go create mode 100644 api/service/payment/wxpay_service.go create mode 100644 api/service/wxlogin_service.go diff --git a/api/core/app_server.go b/api/core/app_server.go index 901dd033..8818d8ab 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -82,18 +82,21 @@ type AppServer struct { Config *types.AppConfig Engine *gin.Engine SysConfig *types.SystemConfig // system config cache + Redis *redis.Client } -func NewServer(appConfig *types.AppConfig) *AppServer { +func NewServer(appConfig *types.AppConfig, redis *redis.Client, sysConfig *types.SystemConfig) *AppServer { gin.SetMode(gin.ReleaseMode) gin.DefaultWriter = io.Discard return &AppServer{ - Config: appConfig, - Engine: gin.Default(), + Config: appConfig, + Redis: redis, + Engine: gin.Default(), + SysConfig: sysConfig, } } -func (s *AppServer) Init(debug bool, client *redis.Client) { +func (s *AppServer) Init(client *redis.Client) { s.Engine.Use(corsMiddleware()) s.Engine.Use(staticResourceMiddleware()) s.Engine.Use(authorizeMiddleware(s, client)) @@ -104,21 +107,6 @@ func (s *AppServer) Init(debug bool, client *redis.Client) { } func (s *AppServer) Run(db *gorm.DB) error { - - // 重命名 config 表字段 - if db.Migrator().HasColumn(&model.Config{}, "config_json") { - db.Migrator().RenameColumn(&model.Config{}, "config_json", "value") - } - if db.Migrator().HasColumn(&model.Config{}, "marker") { - db.Migrator().RenameColumn(&model.Config{}, "marker", "name") - } - if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") { - db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key") - } - if db.Migrator().HasIndex(&model.Config{}, "marker") { - db.Migrator().DropIndex(&model.Config{}, "marker") - } - // load system configs var sysConfig model.Config err := db.Where("name", "system").First(&sysConfig).Error @@ -130,57 +118,6 @@ func (s *AppServer) Run(db *gorm.DB) error { return fmt.Errorf("failed to decode system config: %v", err) } - // 迁移数据表 - logger.Info("Migrating database tables...") - db.AutoMigrate( - &model.ChatItem{}, - &model.ChatMessage{}, - &model.ChatRole{}, - &model.ChatModel{}, - &model.InviteCode{}, - &model.InviteLog{}, - &model.Menu{}, - &model.Order{}, - &model.Product{}, - &model.User{}, - &model.Function{}, - &model.File{}, - &model.Redeem{}, - &model.Config{}, - &model.ApiKey{}, - &model.AdminUser{}, - &model.AppType{}, - &model.SdJob{}, - &model.SunoJob{}, - &model.PowerLog{}, - &model.VideoJob{}, - &model.MidJourneyJob{}, - &model.UserLoginLog{}, - &model.DallJob{}, - &model.JimengJob{}, - ) - // 手动删除字段 - if db.Migrator().HasColumn(&model.Order{}, "deleted_at") { - db.Migrator().DropColumn(&model.Order{}, "deleted_at") - } - if db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") { - db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at") - } - if db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") { - db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at") - } - if db.Migrator().HasColumn(&model.User{}, "chat_config") { - db.Migrator().DropColumn(&model.User{}, "chat_config") - } - if db.Migrator().HasColumn(&model.ChatModel{}, "category") { - db.Migrator().DropColumn(&model.ChatModel{}, "category") - } - if db.Migrator().HasColumn(&model.ChatModel{}, "description") { - db.Migrator().DropColumn(&model.ChatModel{}, "description") - } - - logger.Info("Database tables migrated successfully") - // 统计安装信息 go func() { info, err := host.Info() diff --git a/api/core/config.go b/api/core/config.go index 365f93c4..70f8727f 100644 --- a/api/core/config.go +++ b/api/core/config.go @@ -11,10 +11,12 @@ import ( "bytes" "geekai/core/types" logger2 "geekai/logger" + "geekai/store/model" "geekai/utils" "os" "github.com/BurntSushi/toml" + "gorm.io/gorm" ) var logger = logger2.GetLogger() @@ -72,3 +74,78 @@ func SaveConfig(config *types.AppConfig) error { return os.WriteFile(config.Path, buf.Bytes(), 0644) } + +func LoadSystemConfig(db *gorm.DB) *types.SystemConfig { + // 加载系统配置 + var sysConfig model.Config + var baseConfig types.BaseConfig + db.Where("name", "system").First(&sysConfig) + err := utils.JsonDecode(sysConfig.Value, &baseConfig) + if err != nil { + logger.Error("load system config error: ", err) + } + + // 加载许可证配置 + var license types.License + sysConfig.Id = 0 + db.Where("name", types.ConfigKeyLicense).First(&sysConfig) + err = utils.JsonDecode(sysConfig.Value, &license) + if err != nil { + logger.Error("load license config error: ", err) + } + + // 加载 GeekAPI 配置 + var geekAPIConfig types.GeekAPIConfig + sysConfig.Id = 0 + db.Where("name", types.ConfigKeyGeekAPI).First(&sysConfig) + err = utils.JsonDecode(sysConfig.Value, &geekAPIConfig) + if err != nil { + logger.Error("load geek service config error: ", err) + } + + // 加载短信配置 + var smsConfig types.SMSConfig + sysConfig.Id = 0 + db.Where("name", types.ConfigKeySms).First(&sysConfig) + err = utils.JsonDecode(sysConfig.Value, &smsConfig) + if err != nil { + logger.Error("load sms config error: ", err) + } + + // 加载 OSS 配置 + var ossConfig types.OSSConfig + sysConfig.Id = 0 + db.Where("name", types.ConfigKeyOss).First(&sysConfig) + err = utils.JsonDecode(sysConfig.Value, &ossConfig) + if err != nil { + logger.Error("load oss config error: ", err) + } + + // 加载 SMTP 配置 + var smtpConfig types.SmtpConfig + sysConfig.Id = 0 + db.Where("name", types.ConfigKeySmtp).First(&sysConfig) + err = utils.JsonDecode(sysConfig.Value, &smtpConfig) + if err != nil { + logger.Error("load smtp config error: ", err) + } + + // 加载支付配置 + var paymentConfig types.PaymentConfig + sysConfig.Id = 0 + db.Where("name", types.ConfigKeyPayment).First(&sysConfig) + err = utils.JsonDecode(sysConfig.Value, &paymentConfig) + if err != nil { + logger.Error("load payment config error: ", err) + } + + return &types.SystemConfig{ + Base: baseConfig, + License: license, + SMS: smsConfig, + OSS: ossConfig, + SMTP: smtpConfig, + Payment: paymentConfig, + GeekAPI: geekAPIConfig, + } +} diff --git a/api/core/midware/auth_midware.go b/api/core/midware/auth_midware.go new file mode 100644 index 00000000..027ad57f --- /dev/null +++ b/api/core/midware/auth_midware.go @@ -0,0 +1,107 @@ +package midware + +import ( + "context" + "fmt" + "geekai/core/types" + "geekai/utils" + "geekai/utils/resp" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "github.com/golang-jwt/jwt" +) + +// 用户授权验证 +func UserAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc { + return func(c *gin.Context) { + tokenString := c.GetHeader(types.UserAuthHeader) + if tokenString == "" { + resp.NotAuth(c, "无效的授权令牌") + c.Abort() + return + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"]) + } + return []byte(secretKey), nil + }) + + if err != nil { + resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err)) + c.Abort() + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + resp.NotAuth(c, "令牌无效") + c.Abort() + return + } + + expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) + if expr > 0 && int64(expr) < time.Now().Unix() { + resp.NotAuth(c, "令牌过期") + c.Abort() + return + } + + key := fmt.Sprintf("users/%v", claims["user_id"]) + if _, err := redis.Get(context.Background(), key).Result(); err != nil { + resp.NotAuth(c, "当前用户已退出登录") + c.Abort() + return + } + c.Set(types.LoginUserID, claims["user_id"]) + } +} + +func AdminAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc { + return func(c *gin.Context) { + tokenString := c.GetHeader(types.AdminAuthHeader) + if tokenString == "" { + resp.NotAuth(c, "无效的授权令牌") + c.Abort() + return + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"]) + } + return []byte(secretKey), nil + }) + + if err != nil { + resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err)) + c.Abort() + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + resp.NotAuth(c, "令牌无效") + c.Abort() + return + } + + expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) + if expr > 0 && int64(expr) < time.Now().Unix() { + resp.NotAuth(c, "令牌过期") + c.Abort() + return + } + + key := fmt.Sprintf("admin/%v", claims["user_id"]) + if _, err := redis.Get(context.Background(), key).Result(); err != nil { + resp.NotAuth(c, "当前用户已退出登录") + c.Abort() + return + } + c.Set(types.AdminUserID, claims["user_id"]) + } +} diff --git a/api/core/midware/parameter_midware.go b/api/core/midware/parameter_midware.go new file mode 100644 index 00000000..ae5d84d1 --- /dev/null +++ b/api/core/midware/parameter_midware.go @@ -0,0 +1,80 @@ +package midware + +import ( + "bytes" + "geekai/utils" + "io" + "strings" + + "github.com/gin-gonic/gin" +) + +// 统一参数处理 +func ParameterHandlerMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // GET 参数处理 + params := c.Request.URL.Query() + for key, values := range params { + for i, value := range values { + params[key][i] = strings.TrimSpace(value) + } + } + // update get parameters + c.Request.URL.RawQuery = params.Encode() + // skip file upload requests + contentType := c.Request.Header.Get("Content-Type") + if strings.Contains(contentType, "multipart/form-data") { + c.Next() + return + } + + if strings.Contains(contentType, "application/json") { + // process POST JSON request body + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + c.Next() + return + } + + // 还原请求体 + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + // 将请求体解析为 JSON + var jsonData map[string]any + if err := c.ShouldBindJSON(&jsonData); err != nil { + c.Next() + return + } + + // 对 JSON 数据中的字符串值去除两端空格 + trimJSONStrings(jsonData) + // 更新请求体 + c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData))) + } + + c.Next() + } +} + +// 递归对 JSON 数据中的字符串值去除两端空格 +func trimJSONStrings(data any) { + switch v := data.(type) { + case map[string]any: + for key, value := range v { + switch valueType := value.(type) { + case string: + v[key] = strings.TrimSpace(valueType) + case map[string]any, []any: + trimJSONStrings(value) + } + } + case []any: + for i, value := range v { + switch valueType := value.(type) { + case string: + v[i] = strings.TrimSpace(valueType) + case map[string]any, []any: + trimJSONStrings(value) + } + } + } +} diff --git a/api/core/midware/rate_limit_midware.go b/api/core/midware/rate_limit_midware.go new file mode 100644 index 00000000..bf8a5166 --- /dev/null +++ b/api/core/midware/rate_limit_midware.go @@ -0,0 +1,43 @@ +package midware + +import ( + "context" + "fmt" + "geekai/core/types" + "geekai/utils" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" +) + +// RateLimitEvery 使用 Redis 做固定间隔限流:在 interval 内仅允许一次请求 +// Key 优先使用登录用户ID,若没有则退化为 route + IP +func RateLimitEvery(redisClient *redis.Client, interval time.Duration) gin.HandlerFunc { + return func(c *gin.Context) { + keyID := "" + if userID, ok := c.Get(types.LoginUserID); ok { + keyID = fmt.Sprintf("user:%s", utils.InterfaceToString(userID)) + } else { + keyID = fmt.Sprintf("ip:%s", c.ClientIP()) + } + + fullPath := c.FullPath() + if fullPath == "" { + fullPath = c.Request.URL.Path + } + key := fmt.Sprintf("rl:%s:%s", fullPath, keyID) + + okSet, err := redisClient.SetNX(context.Background(), key, 1, interval).Result() + if err != nil { + // Redis 异常时放行,避免误伤可用性 + return + } + if !okSet { + c.JSON(http.StatusTooManyRequests, types.BizVo{Code: types.Failed, Message: "请求过于频繁,请稍后重试"}) + c.Abort() + return + } + } +} diff --git a/api/core/types/config.go b/api/core/types/config.go index 248b4bfe..7ea21ac4 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -17,18 +17,17 @@ type AppConfig struct { Session Session AdminSession Session ProxyURL string - MysqlDns string // mysql 连接地址 - StaticDir string // 静态资源目录 - StaticUrl string // 静态资源 URL - Redis RedisConfig // redis 连接信息 - SMS SMSConfig // send mobile message config - OSS OSSConfig // OSS config - SmtpConfig SmtpConfig // 邮件发送配置 - AlipayConfig AlipayConfig // 支付宝支付渠道配置 - HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置 - GeekPayConfig GeekPayConfig // GEEK 支付配置 - WechatPayConfig WxPayConfig // 微信支付渠道配置 - TikaHost string // TiKa 服务器地址 + MysqlDns string // mysql 连接地址 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + SMS SMSConfig // send mobile message config + OSS OSSConfig // OSS config + SmtpConfig SmtpConfig // 邮件发送配置 + AlipayConfig AlipayConfig // 支付宝支付渠道配置 + GeekPayConfig EpayConfig // GEEK 支付配置 + WechatPayConfig WxPayConfig // 微信支付渠道配置 + TikaHost string // TiKa 服务器地址 } type RedisConfig struct { @@ -58,7 +57,7 @@ func (c RedisConfig) Url() string { return fmt.Sprintf("%s:%d", c.Host, c.Port) } -type SystemConfig struct { +type BaseConfig struct { Title string `json:"title,omitempty"` // 网站标题 Slogan string `json:"slogan,omitempty"` // 网站 slogan AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题 @@ -103,18 +102,29 @@ type SystemConfig struct { EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表 AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB +} +type SystemConfig struct { + Base BaseConfig + Payment PaymentConfig + OSS OSSConfig + SMS SMSConfig + SMTP SmtpConfig + GeekAPI GeekAPIConfig + Jimeng JimengConfig + License License } // 配置键名常量 const ( - ConfigKeySystem = "system" - ConfigKeyNotice = "notice" - ConfigKeyAgreement = "agreement" - ConfigKeyPrivacy = "privacy" - ConfigKeyGeekService = "geekai" - ConfigKeySms = "sms" - ConfigKeySmtp = "smtp" - ConfigKeyOss = "oss" - ConfigKeyPayment = "payment" + ConfigKeySystem = "system" + ConfigKeyNotice = "notice" + ConfigKeyAgreement = "agreement" + ConfigKeyPrivacy = "privacy" + ConfigKeyGeekAPI = "geekapi" + ConfigKeyLicense = "license" + ConfigKeySms = "sms" + ConfigKeySmtp = "smtp" + ConfigKeyOss = "oss" + ConfigKeyPayment = "payment" ) diff --git a/api/core/types/geekai.go b/api/core/types/geekai.go index 1a7afd1d..fd674f34 100644 --- a/api/core/types/geekai.go +++ b/api/core/types/geekai.go @@ -23,3 +23,8 @@ type WxLoginConfig struct { NotifyURL string `json:"notify_url"` // 登录成功回调 URL Enabled bool `json:"enabled"` // 是否启用微信登录 } + +type GeekAPIConfig struct { + Captcha CaptchaConfig + WxLogin WxLoginConfig +} diff --git a/api/core/types/order.go b/api/core/types/order.go index c0dd13ac..41096d37 100644 --- a/api/core/types/order.go +++ b/api/core/types/order.go @@ -13,27 +13,24 @@ const ( OrderNotPaid = OrderStatus(0) OrderScanned = OrderStatus(1) // 已扫码 OrderPaidSuccess = OrderStatus(2) + OrderPaidFailed = OrderStatus(3) ) type OrderRemark struct { - Days int `json:"days"` // 有效期 - Power int `json:"power"` // 增加算力点数 - Name string `json:"name"` // 产品名称 - Price float64 `json:"price"` - Discount float64 `json:"discount"` + Days int `json:"days"` // 有效期 + Power int `json:"power"` // 增加算力点数 + Name string `json:"name"` // 产品名称 + Price float64 `json:"price"` } -var PayMethods = map[string]string{ +// PayChannel 支付渠道 +var PayChannel = map[string]string{ "alipay": "支付宝商号", - "wechat": "微信商号", - "hupi": "虎皮椒", - "geek": "易支付", + "wxpay": "微信商号", + "epay": "易支付", } -var PayNames = map[string]string{ + +var PayWays = map[string]string{ "alipay": "支付宝", "wxpay": "微信支付", - "qqpay": "QQ钱包", - "jdpay": "京东支付", - "douyin": "抖音支付", - "paypal": "PayPal支付", } diff --git a/api/core/types/payment.go b/api/core/types/payment.go index 4ec45642..e60d71b5 100644 --- a/api/core/types/payment.go +++ b/api/core/types/payment.go @@ -1,19 +1,9 @@ package types type PaymentConfig struct { - AlipayConfig AlipayConfig `json:"alipay"` // 支付宝支付渠道配置 - GeekPayConfig GeekPayConfig `json:"geekpay"` // GEEK 支付配置 - WxPayConfig WxPayConfig `json:"wxpay"` // 微信支付渠道配置 - HuPiPayConfig HuPiPayConfig `json:"hupi"` // 虎皮椒支付渠道配置 -} - -type HuPiPayConfig struct { //虎皮椒第四方支付配置 - Enabled bool // 是否启用该支付通道 - AppId string // App ID - AppSecret string // app 密钥 - ApiURL string // 支付网关 - NotifyURL string // 异步通知地址 - ReturnURL string // 同步通知地址 + Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置 + Epay EpayConfig `json:"epay"` // GEEK 支付配置 + WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置 } // AlipayConfig 支付宝支付配置 @@ -53,8 +43,8 @@ func (c *WxPayConfig) Equal(other *WxPayConfig) bool { c.Domain == other.Domain } -// GeekPayConfig 易支付配置 -type GeekPayConfig struct { +// EpayConfig 易支付配置 +type EpayConfig struct { Enabled bool `json:"enabled"` // 是否启用该支付通道 AppId string `json:"app_id"` // 商户 ID PrivateKey string `json:"private_key"` // 私钥 @@ -62,7 +52,7 @@ type GeekPayConfig struct { Domain string `json:"domain"` // 支付回调域名 } -func (c *GeekPayConfig) Equal(other *GeekPayConfig) bool { +func (c *EpayConfig) Equal(other *EpayConfig) bool { return c.AppId == other.AppId && c.PrivateKey == other.PrivateKey && c.ApiURL == other.ApiURL && diff --git a/api/core/types/session.go b/api/core/types/session.go index 9108e51a..652f4c1f 100644 --- a/api/core/types/session.go +++ b/api/core/types/session.go @@ -8,6 +8,7 @@ package types // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ const LoginUserID = "LOGIN_USER_ID" +const AdminUserID = "ADMIN_USER_ID" const LoginUserCache = "LOGIN_USER_CACHE" const UserAuthHeader = "Authorization" diff --git a/api/core/types/sms.go b/api/core/types/sms.go index 510e8071..3db10668 100644 --- a/api/core/types/sms.go +++ b/api/core/types/sms.go @@ -17,8 +17,6 @@ type SMSConfig struct { type SmsConfigAli struct { AccessKey string AccessSecret string - Product string - Domain string Sign string // 短信签名 CodeTempId string // 验证码短信模板 ID } @@ -27,7 +25,6 @@ type SmsConfigAli struct { type SmsConfigBao struct { Username string //短信宝平台注册的用户名 Password string //短信宝平台注册的密码 - Domain string //域名 Sign string // 短信签名 CodeTemplate string // 验证码短信模板 匹配 } diff --git a/api/go.mod b/api/go.mod index adbcf263..2eb0affc 100644 --- a/api/go.mod +++ b/api/go.mod @@ -27,6 +27,7 @@ require ( require ( github.com/go-pay/gopay v1.5.101 github.com/go-rod/rod v0.116.2 + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/go-tika v0.3.1 github.com/microcosm-cc/bluemonday v1.0.26 github.com/sashabaranov/go-openai v1.38.1 diff --git a/api/go.sum b/api/go.sum index addc00ee..6b50c5ab 100644 --- a/api/go.sum +++ b/api/go.sum @@ -87,6 +87,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4 github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/api/handler/admin/admin_handler.go b/api/handler/admin/admin_handler.go index ba538b66..65922932 100644 --- a/api/handler/admin/admin_handler.go +++ b/api/handler/admin/admin_handler.go @@ -19,9 +19,10 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" + "time" + "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" - "time" "github.com/gin-gonic/gin" "gorm.io/gorm" @@ -72,7 +73,7 @@ func (h *ManagerHandler) Login(c *gin.Context) { return } - if h.App.SysConfig.EnabledVerify { + if h.App.SysConfig.Base.EnabledVerify { var check bool if data.X != 0 { check = h.captcha.SlideCheck(data) diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index 1676ff56..22afe0ae 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -47,7 +47,6 @@ func (h *ConfigHandler) RegisterRoutes() { group.GET("get", h.Get) group.POST("active", h.Active) group.POST("test", h.Test) - group.GET("fixData", h.FixData) group.GET("license", h.GetLicense) } @@ -70,7 +69,7 @@ func (h *ConfigHandler) Update(c *gin.Context) { resp.ERROR(c, "系统配置解析失败: "+err.Error()) return } - if (sys.Copyright != payload.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy { + if (sys.Base.Copyright != payload.ConfigBak.Base.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy { resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权") return } @@ -162,69 +161,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) { license := h.licenseService.GetLicense() resp.SUCCESS(c, license) } - -// FixData 修复数据 -func (h *ConfigHandler) FixData(c *gin.Context) { - resp.ERROR(c, "当前升级版本没有数据需要修正!") - //var fixed bool - //version := "data_fix_4.1.4" - //err := h.levelDB.Get(version, &fixed) - //if err == nil || fixed { - // resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作") - // return - //} - //tx := h.DB.Begin() - //var users []model.User - //err = tx.Find(&users).Error - //if err != nil { - // resp.ERROR(c, err.Error()) - // return - //} - //for _, user := range users { - // if user.Email != "" || user.Mobile != "" { - // continue - // } - // if utils.IsValidEmail(user.Username) { - // user.Email = user.Username - // } else if utils.IsValidMobile(user.Username) { - // user.Mobile = user.Username - // } - // err = tx.Save(&user).Error - // if err != nil { - // resp.ERROR(c, err.Error()) - // tx.Rollback() - // return - // } - //} - // - //var orders []model.Order - //err = h.DB.Find(&orders).Error - //if err != nil { - // resp.ERROR(c, err.Error()) - // return - //} - //for _, order := range orders { - // if order.PayWay == "支付宝" { - // order.PayWay = "alipay" - // order.PayType = "alipay" - // } else if order.PayWay == "微信支付" { - // order.PayWay = "wechat" - // order.PayType = "wxpay" - // } else if order.PayWay == "hupi" { - // order.PayType = "wxpay" - // } - // err = tx.Save(&order).Error - // if err != nil { - // resp.ERROR(c, err.Error()) - // tx.Rollback() - // return - // } - //} - //tx.Commit() - //err = h.levelDB.Put(version, true) - //if err != nil { - // resp.ERROR(c, err.Error()) - // return - //} - //resp.SUCCESS(c) -} diff --git a/api/handler/admin/order_handler.go b/api/handler/admin/order_handler.go index bebe0f2b..a4199b32 100644 --- a/api/handler/admin/order_handler.go +++ b/api/handler/admin/order_handler.go @@ -76,16 +76,16 @@ func (h *OrderHandler) List(c *gin.Context) { order.Id = item.Id order.CreatedAt = item.CreatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix() - payMethod, ok := types.PayMethods[item.PayWay] + payChannel, ok := types.PayChannel[item.Channel] if !ok { - payMethod = item.PayWay + payChannel = item.Channel } - payName, ok := types.PayNames[item.PayType] + payWays, ok := types.PayWays[item.PayWay] if !ok { - payName = item.PayWay + payWays = item.PayWay } - order.PayMethod = payMethod - order.PayName = payName + order.ChannelName = payChannel + order.PayName = payWays list = append(list, order) } else { logger.Error(err) diff --git a/api/handler/admin/upload_handler.go b/api/handler/admin/upload_handler.go index 4a6a21ca..4d973f89 100644 --- a/api/handler/admin/upload_handler.go +++ b/api/handler/admin/upload_handler.go @@ -41,7 +41,7 @@ func (h *UploadHandler) Upload(c *gin.Context) { return } - if h.App.SysConfig.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.MaxFileSize)*1024*1024 { + if h.App.SysConfig.Base.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.Base.MaxFileSize)*1024*1024 { resp.ERROR(c, "文件大小超过限制") return } diff --git a/api/handler/captcha_handler.go b/api/handler/captcha_handler.go index 4659ceef..0abf0cd4 100644 --- a/api/handler/captcha_handler.go +++ b/api/handler/captcha_handler.go @@ -21,11 +21,11 @@ import ( type CaptchaHandler struct { App *core.AppServer service *service.CaptchaService - config *types.CaptchaConfig + config types.CaptchaConfig } -func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, config *types.CaptchaConfig) *CaptchaHandler { - return &CaptchaHandler{App: app, service: s, config: config} +func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, sysConfig *types.SystemConfig) *CaptchaHandler { + return &CaptchaHandler{App: app, service: s, config: sysConfig.GeekAPI.Captcha} } // RegisterRoutes 注册路由 diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index 0aad7821..f6942502 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -244,9 +244,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C // 加载聊天上下文 chatCtx := make([]any, 0) messages := make([]any, 0) - if h.App.SysConfig.EnableContext { + if h.App.SysConfig.Base.EnableContext { _ = utils.JsonDecode(input.ChatRole.Context, &messages) - if h.App.SysConfig.ContextDeep > 0 { + if h.App.SysConfig.Base.ContextDeep > 0 { var historyMessages []model.ChatMessage dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId) if input.LastMsgId > 0 { // 重新生成逻辑 @@ -254,7 +254,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C // 删除对应的聊天记录 h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{}) } - err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error + err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error if err == nil { for i := len(historyMessages) - 1; i >= 0; i-- { msg := historyMessages[i] @@ -282,7 +282,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C } // 上下文的深度超出了模型的最大上下文深度 - if len(chatCtx) >= h.App.SysConfig.ContextDeep { + if len(chatCtx) >= h.App.SysConfig.Base.ContextDeep { break } diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index 3a1fdb99..4026db76 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -88,7 +88,7 @@ func (h *DallJobHandler) Image(c *gin.Context) { Quality: data.Quality, Size: data.Size, Style: data.Style, - TranslateModelId: h.App.SysConfig.AssistantModelId, + TranslateModelId: h.App.SysConfig.Base.AssistantModelId, Power: chatModel.Power, } job := model.DallJob{ diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index 38540fbe..ac2a624f 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -197,7 +197,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { return } - if user.Power < h.App.SysConfig.DallPower { + if user.Power < h.App.SysConfig.Base.DallPower { resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足") return } @@ -209,17 +209,17 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { Prompt: prompt, ModelId: 0, ModelName: "dall-e-3", - TranslateModelId: h.App.SysConfig.AssistantModelId, + TranslateModelId: h.App.SysConfig.Base.AssistantModelId, N: 1, Quality: "standard", Size: "1024x1024", Style: "vivid", - Power: h.App.SysConfig.DallPower, + Power: h.App.SysConfig.Base.DallPower, } job := model.DallJob{ UserId: user.Id, Prompt: prompt, - Power: h.App.SysConfig.DallPower, + Power: h.App.SysConfig.Base.DallPower, TaskInfo: utils.JsonEncode(task), } err := h.DB.Create(&job).Error diff --git a/api/handler/menu_handler.go b/api/handler/menu_handler.go index a0eba073..9e1df9ea 100644 --- a/api/handler/menu_handler.go +++ b/api/handler/menu_handler.go @@ -39,7 +39,7 @@ func (h *MenuHandler) List(c *gin.Context) { session := h.DB.Session(&gorm.Session{}) session = session.Where("enabled", true) if index { - session = session.Where("id IN ?", h.App.SysConfig.IndexNavs) + session = session.Where("id IN ?", h.App.SysConfig.Base.IndexNavs) } res := session.Order("sort_num ASC").Find(&items) if res.Error == nil { diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index e307c2e0..91269aeb 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -65,7 +65,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { return false } - if user.Power < h.App.SysConfig.MjPower { + if user.Power < h.App.SysConfig.Base.MjPower { resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") return false } @@ -171,8 +171,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { Params: params, UserId: userId, ImgArr: data.ImgArr, - Mode: h.App.SysConfig.MjMode, - TranslateModelId: h.App.SysConfig.AssistantModelId, + Mode: h.App.SysConfig.Base.MjMode, + TranslateModelId: h.App.SysConfig.Base.AssistantModelId, } job := model.MidJourneyJob{ Type: data.TaskType, @@ -181,7 +181,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { TaskInfo: utils.JsonEncode(task), Progress: 0, Prompt: fmt.Sprintf("%s %s", data.Prompt, params), - Power: h.App.SysConfig.MjPower, + Power: h.App.SysConfig.Base.MjPower, CreatedAt: time.Now(), } opt := "绘图" @@ -244,7 +244,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, - Mode: h.App.SysConfig.MjMode, + Mode: h.App.SysConfig.Base.MjMode, } job := model.MidJourneyJob{ Type: types.TaskUpscale.String(), @@ -252,7 +252,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { TaskId: taskId, TaskInfo: utils.JsonEncode(task), Progress: 0, - Power: h.App.SysConfig.MjActionPower, + Power: h.App.SysConfig.Base.MjActionPower, CreatedAt: time.Now(), } if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { @@ -299,7 +299,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { ChannelId: data.ChannelId, MessageId: data.MessageId, MessageHash: data.MessageHash, - Mode: h.App.SysConfig.MjMode, + Mode: h.App.SysConfig.Base.MjMode, } job := model.MidJourneyJob{ Type: types.TaskVariation.String(), @@ -308,7 +308,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { TaskId: taskId, TaskInfo: utils.JsonEncode(task), Progress: 0, - Power: h.App.SysConfig.MjActionPower, + Power: h.App.SysConfig.Base.MjActionPower, CreatedAt: time.Now(), } if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { diff --git a/api/handler/order_handler.go b/api/handler/order_handler.go index fe7d3206..feb9d925 100644 --- a/api/handler/order_handler.go +++ b/api/handler/order_handler.go @@ -55,20 +55,21 @@ func (h *OrderHandler) List(c *gin.Context) { order.Id = item.Id order.CreatedAt = item.CreatedAt.Unix() order.UpdatedAt = item.UpdatedAt.Unix() - payMethod, ok := types.PayMethods[item.PayWay] + payChannel, ok := types.PayChannel[item.Channel] if !ok { - payMethod = item.PayWay + payChannel = item.PayWay } - payName, ok := types.PayNames[item.PayType] + payWays, ok := types.PayWays[item.PayWay] if !ok { - payName = item.PayWay + payWays = item.PayWay } - order.PayMethod = payMethod - order.PayName = payName + order.ChannelName = payChannel + order.PayName = payWays list = append(list, order) } else { logger.Error(err) } + } } resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list)) diff --git a/api/handler/payment_handler.go b/api/handler/payment_handler.go index 0ff0b6f3..a1026f94 100644 --- a/api/handler/payment_handler.go +++ b/api/handler/payment_handler.go @@ -9,8 +9,10 @@ package handler import ( "embed" + "errors" "fmt" "geekai/core" + "geekai/core/midware" "geekai/core/types" "geekai/service" "geekai/service/payment" @@ -33,63 +35,143 @@ type PayWay struct { // PaymentHandler 支付服务回调 handler type PaymentHandler struct { BaseHandler - alipayService *payment.AlipayService - huPiPayService *payment.HuPiPayService - geekPayService *payment.GeekPayService - wechatPayService *payment.WechatPayService - snowflake *service.Snowflake - userService *service.UserService - fs embed.FS - lock sync.Mutex - signKey string // 用来签名的随机秘钥 + alipayService *payment.AlipayService + epayService *payment.EPayService + wxpayService *payment.WxPayService + snowflake *service.Snowflake + userService *service.UserService + fs embed.FS + lock sync.Mutex + config *types.PaymentConfig } func NewPaymentHandler( server *core.AppServer, alipayService *payment.AlipayService, - huPiPayService *payment.HuPiPayService, - geekPayService *payment.GeekPayService, - wechatPayService *payment.WechatPayService, + geekPayService *payment.EPayService, + wxpayService *payment.WxPayService, db *gorm.DB, userService *service.UserService, snowflake *service.Snowflake, - fs embed.FS) *PaymentHandler { + fs embed.FS, + sysConfig *types.SystemConfig) *PaymentHandler { return &PaymentHandler{ - alipayService: alipayService, - huPiPayService: huPiPayService, - geekPayService: geekPayService, - wechatPayService: wechatPayService, - snowflake: snowflake, - userService: userService, - fs: fs, - lock: sync.Mutex{}, + alipayService: alipayService, + epayService: geekPayService, + wxpayService: wxpayService, + snowflake: snowflake, + userService: userService, + fs: fs, + lock: sync.Mutex{}, BaseHandler: BaseHandler{ App: server, DB: db, }, - signKey: utils.RandString(32), + config: &sysConfig.Payment, } } // RegisterRoutes 注册路由 func (h *PaymentHandler) RegisterRoutes() { - group := h.App.Engine.Group("/api/payment/") - group.POST("doPay", h.Pay) - group.GET("payWays", h.GetPayWays) - group.POST("notify/alipay", h.AlipayNotify) - group.GET("notify/geek", h.GeekPayNotify) - group.POST("notify/wechat", h.WechatPayNotify) - group.POST("notify/hupi", h.HuPiPayNotify) + rg := h.App.Engine.Group("/api/payment/") + + // 支付回调接口(公开) + rg.POST("notify/alipay", h.AlipayNotify) + rg.GET("notify/geek", h.GeekPayNotify) + rg.POST("notify/wechat", h.WechatPayNotify) + + // 需要用户登录的接口 + rg.Use(midware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis)) + { + rg.POST("create", h.Pay) + } + + // 同步订单状态 + h.StartSyncOrders() +} + +func (h *PaymentHandler) StartSyncOrders() { + go func() { + for { + err := h.SyncOrders() + if err != nil { + logger.Error(err) + } + time.Sleep(time.Second * 5) + } + }() +} + +// SyncOrders 同步订单状态 +func (h *PaymentHandler) SyncOrders() error { + defer func() { + if err := recover(); err != nil { + logger.Errorf("同步订单状态发生异常: %v", err) + } + }() + var orders []model.Order + err := h.DB.Where("status", types.OrderNotPaid).Where("checked", false).Find(&orders).Error + if err != nil { + return err + } + + for _, order := range orders { + // 超时15分钟的订单,直接标记为已关闭 + if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) { + h.DB.Model(&model.Order{}).Where("id", order.Id).Update("checked", true) + return errors.New("订单超时") + } + // 查询订单状态 + var res payment.OrderInfo + switch order.Channel { + case payment.PayChannelEpay: + res, err = h.epayService.Query(order.OrderNo) + if err != nil { + return fmt.Errorf("error with query order info: %v", err) + } + // 微信支付 + case payment.PayChannelWX: + res, err = h.wxpayService.Query(order.OrderNo) + if err != nil { + return fmt.Errorf("error with query order info: %v", err) + } + case payment.PayChannelAL: + res, err = h.alipayService.Query(order.OrderNo) + if err != nil { + return fmt.Errorf("error with query order info: %v", err) + } + } + + // 订单已关闭 + if res.Closed() { + h.DB.Model(&model.Order{}).Where("id", order.Id).Updates(map[string]any{ + "checked": true, + "status": types.OrderPaidFailed, + }) + return errors.New("订单已关闭") + } + + // 订单未支付,不处理,继续轮询 + if !res.Success() { + return nil + } + + // 订单支付成功 + err = h.paySuccess(res) + if err != nil { + return fmt.Errorf("error with deal order: %v", err) + } + } + return nil } func (h *PaymentHandler) Pay(c *gin.Context) { var data struct { - PayWay string `json:"pay_way"` - PayType string `json:"pay_type"` - ProductId int `json:"product_id"` - UserId int `json:"user_id"` - Device string `json:"device"` - Host string `json:"host"` + PayWay string `json:"pay_way,omitempty"` // 支付方式:支付宝,微信 + Pid int `json:"pid,omitempty"` + Device string `json:"device,omitempty"` + Domain string `json:"domain,omitempty"` // 支付回调域名 + Channel string `json:"channel,omitempty"` } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) @@ -97,7 +179,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) { } var product model.Product - err := h.DB.Where("id", data.ProductId).First(&product).Error + err := h.DB.Where("id", data.Pid).First(&product).Error if err != nil { resp.ERROR(c, "Product not found") return @@ -108,136 +190,118 @@ func (h *PaymentHandler) Pay(c *gin.Context) { resp.ERROR(c, "error with generate trade no: "+err.Error()) return } + userId := h.GetLoginUserId(c) var user model.User - err = h.DB.Where("id", data.UserId).First(&user).Error + err = h.DB.Where("id", userId).First(&user).Error if err != nil { resp.NotAuth(c) return } - amount := product.Discount - var payURL, returnURL, notifyURL string + amount := product.Price + var payURL, notifyURL string switch data.PayWay { - case "alipay": - if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付 - notifyURL = h.App.Config.AlipayConfig.NotifyURL - } else { - notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host) - } - if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付 - returnURL = h.App.Config.AlipayConfig.ReturnURL - } else { - returnURL = fmt.Sprintf("%s/payReturn", data.Host) - } - money := fmt.Sprintf("%.2f", amount) - if data.Device == "wechat" { - payURL, err = h.alipayService.PayMobile(payment.AlipayParams{ + case "wxpay": + logger.Debugf("微信支付,%+v", data) + data.Channel = payment.PayChannelWX + // 优先使用微信官方支付 + if h.config.WxPay.Enabled { + data.Channel = "wxpay" + if h.config.WxPay.Domain != "" { + data.Domain = h.config.WxPay.Domain + } + notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Domain) + payURL, err = h.wxpayService.Pay(payment.PayRequest{ OutTradeNo: orderNo, - Subject: product.Name, - TotalFee: money, - ReturnURL: returnURL, - NotifyURL: notifyURL, - }) - } else { - payURL, err = h.alipayService.PayPC(payment.AlipayParams{ - OutTradeNo: orderNo, - Subject: product.Name, - TotalFee: money, - ReturnURL: returnURL, - NotifyURL: notifyURL, - }) - } - - if err != nil { - resp.ERROR(c, "error with generate pay url: "+err.Error()) - return - } - break - case "wechat": - if h.App.Config.WechatPayConfig.NotifyURL != "" { - notifyURL = h.App.Config.WechatPayConfig.NotifyURL - } else { - notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host) - } - if data.Device == "wechat" { - payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{ - OutTradeNo: orderNo, - TotalFee: int(amount * 100), + TotalFee: fmt.Sprintf("%d", int(amount*100)), Subject: product.Name, NotifyURL: notifyURL, ClientIP: c.ClientIP(), + Device: data.Device, + PayWay: payment.PayWayWX, }) - } else { - payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{ + if err != nil { + resp.ERROR(c, err.Error()) + return + } + } else if h.config.Epay.Enabled { // 聚合支付 + logger.Debugf("聚合支付%+v", data) + data.Channel = payment.PayChannelEpay + if h.config.Epay.Domain != "" { + data.Domain = h.config.Epay.Domain + } + notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Domain) + params := payment.PayRequest{ OutTradeNo: orderNo, - TotalFee: int(amount * 100), Subject: product.Name, + TotalFee: fmt.Sprintf("%f", amount), + ClientIP: c.ClientIP(), + Device: data.Device, + PayWay: payment.PayWayWX, + NotifyURL: notifyURL, + } + + r, err := h.epayService.Pay(params) + logger.Debugf("请求支付结果,%+v", r) + if err != nil { + resp.ERROR(c, err.Error()) + return + } else { + payURL = r + } + } else { + resp.ERROR(c, "系统没有配置可用的支付渠道!") + return + } + case "alipay": + if h.config.Alipay.Enabled { + logger.Debugf("支付宝,%+v", data) + data.Channel = payment.PayChannelAL + if h.config.Alipay.Domain != "" { // 用于本地调试支付 + data.Domain = h.config.Alipay.Domain + } + notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Domain) + money := fmt.Sprintf("%.2f", amount) + payURL, err = h.alipayService.Pay(payment.PayRequest{ + Device: data.Device, + OutTradeNo: orderNo, + Subject: product.Name, + TotalFee: money, NotifyURL: notifyURL, }) - } - if err != nil { - resp.ERROR(c, err.Error()) - return - } - break - case "hupi": - if h.App.Config.HuPiPayConfig.NotifyURL != "" { - notifyURL = h.App.Config.HuPiPayConfig.NotifyURL - } else { - notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host) - } - if h.App.Config.HuPiPayConfig.ReturnURL != "" { - returnURL = h.App.Config.HuPiPayConfig.ReturnURL - } else { - returnURL = fmt.Sprintf("%s/payReturn", data.Host) - } - r, err := h.huPiPayService.Pay(payment.HuPiPayParams{ - Version: "1.1", - TradeOrderId: orderNo, - TotalFee: fmt.Sprintf("%f", amount), - Title: product.Name, - NotifyURL: notifyURL, - ReturnURL: returnURL, - WapName: "GeekAI助手", - }) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - payURL = r.URL - break - case "geek": - if h.App.Config.GeekPayConfig.NotifyURL != "" { - notifyURL = h.App.Config.GeekPayConfig.NotifyURL - } else { - notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host) - } - if h.App.Config.GeekPayConfig.ReturnURL != "" { - data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL) - } - if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面 - returnURL = fmt.Sprintf("%s/mobile/profile", data.Host) - } else { - returnURL = fmt.Sprintf("%s/payReturn", data.Host) - } - params := payment.GeekPayParams{ - OutTradeNo: orderNo, - Method: "web", - Name: product.Name, - Money: fmt.Sprintf("%f", amount), - ClientIP: c.ClientIP(), - Device: data.Device, - Type: data.PayType, - ReturnURL: returnURL, - NotifyURL: notifyURL, - } - res, err := h.geekPayService.Pay(params) - if err != nil { - resp.ERROR(c, err.Error()) + if err != nil { + resp.ERROR(c, "error with generate pay url: "+err.Error()) + return + } + } else if h.config.Epay.Enabled { // 聚合支付 + logger.Debugf("聚合支付,%+v", data) + data.Channel = payment.PayChannelEpay + if h.config.Epay.Domain != "" { + data.Domain = h.config.Epay.Domain + } + notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Domain) + params := payment.PayRequest{ + OutTradeNo: orderNo, + Subject: product.Name, + TotalFee: fmt.Sprintf("%f", amount), + ClientIP: c.ClientIP(), + Device: data.Device, + PayWay: data.PayWay, + NotifyURL: notifyURL, + } + + r, err := h.epayService.Pay(params) + if err != nil { + resp.ERROR(c, err.Error()) + return + } else { + payURL = r + } + } else { + resp.ERROR(c, "系统没有配置可用的支付渠道!") return } - payURL = res.PayURL default: resp.ERROR(c, "不支持的支付渠道") return @@ -245,43 +309,41 @@ func (h *PaymentHandler) Pay(c *gin.Context) { // 创建订单 remark := types.OrderRemark{ - Days: product.Days, - Power: product.Power, - Name: product.Name, - Price: product.Price, - Discount: product.Discount, + Days: product.Days, + Power: product.Power, + Name: product.Name, + Price: product.Price, } order := model.Order{ - UserId: user.Id, - Username: user.Username, - ProductId: product.Id, - OrderNo: orderNo, - Subject: product.Name, - Amount: amount, - Status: types.OrderNotPaid, - PayWay: data.PayWay, - PayType: data.PayType, - Remark: utils.JsonEncode(remark), + UserId: user.Id, + Username: user.Username, + OrderNo: orderNo, + Subject: product.Name, + Amount: amount, + Status: types.OrderNotPaid, + PayWay: data.PayWay, + Channel: data.Channel, + Remark: utils.JsonEncode(remark), } err = h.DB.Create(&order).Error if err != nil { resp.ERROR(c, "error with create order: "+err.Error()) return } - resp.SUCCESS(c, payURL) + resp.SUCCESS(c, gin.H{"pay_url": payURL, "order_no": orderNo}) } -// 异步通知回调公共逻辑 -func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { +// 支付成功处理 +func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error { + h.lock.Lock() + defer h.lock.Unlock() + var order model.Order - err := h.DB.Where("order_no = ?", orderNo).First(&order).Error + err := h.DB.Where("order_no", info.OutTradeNo).First(&order).Error if err != nil { return fmt.Errorf("error with fetch order: %v", err) } - h.lock.Lock() - defer h.lock.Unlock() - // 已支付订单,直接返回 if order.Status == types.OrderPaidSuccess { return nil @@ -301,18 +363,20 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { // 增加用户算力 err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{ - Type: types.PowerRecharge, - Model: order.PayWay, - Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo), + Type: types.PowerRecharge, + Model: order.Subject, + Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo), + CreatedAt: time.Now(), }) if err != nil { return err } // 更新订单状态 - order.PayTime = time.Now().Unix() + order.PayTime = utils.Str2stamp(info.PayTime) order.Status = types.OrderPaidSuccess - order.TradeNo = tradeNo + order.TradeNo = info.TradeId + order.Checked = true err = h.DB.Updates(&order).Error if err != nil { return fmt.Errorf("error with update order info: %v", err) @@ -328,54 +392,6 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { return nil } -// GetPayWays 获取支付方式 -func (h *PaymentHandler) GetPayWays(c *gin.Context) { - payWays := make([]gin.H, 0) - if h.App.Config.AlipayConfig.Enabled { - payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"}) - } - if h.App.Config.HuPiPayConfig.Enabled { - payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"}) - } - if h.App.Config.GeekPayConfig.Enabled { - for _, v := range h.App.Config.GeekPayConfig.Methods { - payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v}) - } - } - if h.App.Config.WechatPayConfig.Enabled { - payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"}) - } - resp.SUCCESS(c, payWays) -} - -// HuPiPayNotify 虎皮椒支付异步回调 -func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) { - err := c.Request.ParseForm() - if err != nil { - c.String(http.StatusOK, "fail") - return - } - - orderNo := c.Request.Form.Get("trade_order_id") - tradeNo := c.Request.Form.Get("open_order_id") - logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form) - - if err = h.huPiPayService.Check(orderNo); err != nil { - logger.Error("订单校验失败:", err) - c.String(http.StatusOK, "fail") - return - } - - err = h.notify(orderNo, tradeNo) - if err != nil { - logger.Error(err) - c.String(http.StatusOK, "fail") - return - } - - c.String(http.StatusOK, "success") -} - // AlipayNotify 支付宝支付回调 func (h *PaymentHandler) AlipayNotify(c *gin.Context) { err := c.Request.ParseForm() @@ -384,16 +400,15 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) { return } - result := h.alipayService.TradeVerify(c.Request) - logger.Infof("收到支付宝商号订单支付回调:%+v", result) - if !result.Success() { - logger.Error("订单校验失败:", result.Message) + orderInfo, err := h.alipayService.Query(c.Request.Form.Get("out_trade_no")) + logger.Infof("收到支付宝商号订单支付回调:%+v", orderInfo) + if !orderInfo.Success() { + logger.Errorf("订单校验失败:%v", err) c.String(http.StatusOK, "fail") return } - tradeNo := c.Request.Form.Get("trade_no") - err = h.notify(result.OutTradeNo, tradeNo) + err = h.paySuccess(orderInfo) if err != nil { logger.Error(err) c.String(http.StatusOK, "fail") @@ -411,20 +426,27 @@ func (h *PaymentHandler) GeekPayNotify(c *gin.Context) { } logger.Infof("收到GeekPay订单支付回调:%+v", params) - // 检查支付状态 + // 检查支付状态, 如果未支付,则返回成功 if params["trade_status"] != "TRADE_SUCCESS" { c.String(http.StatusOK, "success") return } - sign := h.geekPayService.Sign(params) + sign := h.epayService.Sign(params) if sign != c.Query("sign") { logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign")) c.String(http.StatusOK, "fail") return } + // 查询订单状态 + order, err := h.epayService.Query(params["out_trade_no"]) + if err != nil { + logger.Error(err) + c.String(http.StatusOK, "fail") + return + } - err := h.notify(params["out_trade_no"], params["trade_no"]) + err = h.paySuccess(order) if err != nil { logger.Error(err) c.String(http.StatusOK, "fail") @@ -442,18 +464,15 @@ func (h *PaymentHandler) WechatPayNotify(c *gin.Context) { return } - result := h.wechatPayService.TradeVerify(c.Request) - logger.Infof("收到微信商号订单支付回调:%+v", result) - if !result.Success() { - logger.Error("订单校验失败:", err) - c.JSON(http.StatusBadRequest, gin.H{ - "code": "FAIL", - "message": err.Error(), - }) + orderInfo, err := h.wxpayService.TradeVerify(c.Request) + logger.Infof("收到微信商号订单支付回调:%+v", orderInfo) + if err != nil { + logger.Errorf("订单校验失败:%v", err) + c.JSON(http.StatusBadRequest, gin.H{"code": "FAIL"}) return } - err = h.notify(result.OutTradeNo, result.TradeId) + err = h.paySuccess(orderInfo) if err != nil { logger.Error(err) c.String(http.StatusOK, "fail") diff --git a/api/handler/prompt_handler.go b/api/handler/prompt_handler.go index 591410f9..ffb55d4d 100644 --- a/api/handler/prompt_handler.go +++ b/api/handler/prompt_handler.go @@ -57,15 +57,15 @@ func (h *PromptHandler) Lyric(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) + content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId) if err != nil { resp.ERROR(c, err.Error()) return } - if h.App.SysConfig.PromptPower > 0 { + if h.App.SysConfig.Base.PromptPower > 0 { userId := h.GetLoginUserId(c) - err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{ + err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.PromptPower, model.PowerLog{ Type: types.PowerConsume, Model: h.getPromptModel(), Remark: "生成歌词", @@ -88,14 +88,14 @@ func (h *PromptHandler) Image(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) + content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId) if err != nil { resp.ERROR(c, err.Error()) return } - if h.App.SysConfig.PromptPower > 0 { + if h.App.SysConfig.Base.PromptPower > 0 { userId := h.GetLoginUserId(c) - err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{ + err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.PromptPower, model.PowerLog{ Type: types.PowerConsume, Model: h.getPromptModel(), Remark: "生成绘画提示词", @@ -117,15 +117,15 @@ func (h *PromptHandler) Video(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId) + content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId) if err != nil { resp.ERROR(c, err.Error()) return } - if h.App.SysConfig.PromptPower > 0 { + if h.App.SysConfig.Base.PromptPower > 0 { userId := h.GetLoginUserId(c) - err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{ + err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.PromptPower, model.PowerLog{ Type: types.PowerConsume, Model: h.getPromptModel(), Remark: "生成视频脚本", @@ -167,9 +167,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) { } func (h *PromptHandler) getPromptModel() string { - if h.App.SysConfig.AssistantModelId > 0 { + if h.App.SysConfig.Base.AssistantModelId > 0 { var chatModel model.ChatModel - h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel) + h.DB.Where("id", h.App.SysConfig.Base.AssistantModelId).First(&chatModel) return chatModel.Value } return "gpt-4o" diff --git a/api/handler/realtime_handler.go b/api/handler/realtime_handler.go index f5874197..8340dc45 100644 --- a/api/handler/realtime_handler.go +++ b/api/handler/realtime_handler.go @@ -160,7 +160,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) { return } - if user.Power < h.App.SysConfig.AdvanceVoicePower { + if user.Power < h.App.SysConfig.Base.AdvanceVoicePower { resp.ERROR(c, "当前用户算力不足,无法使用该功能") return } @@ -204,7 +204,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) { h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) // 扣减算力 - err = h.userService.DecreasePower(userId, h.App.SysConfig.AdvanceVoicePower, model.PowerLog{ + err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.AdvanceVoicePower, model.PowerLog{ Type: types.PowerConsume, Model: "advanced-voice", Remark: "实时语音通话", diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 51a6fe08..e1bb46b8 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -73,7 +73,7 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool { return false } - if user.Power < h.App.SysConfig.SdPower { + if user.Power < h.App.SysConfig.Base.SdPower { resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") return false } @@ -141,7 +141,7 @@ func (h *SdJobHandler) Image(c *gin.Context) { HdSteps: data.HdSteps, }, UserId: userId, - TranslateModelId: h.App.SysConfig.AssistantModelId, + TranslateModelId: h.App.SysConfig.Base.AssistantModelId, } job := model.SdJob{ @@ -152,7 +152,7 @@ func (h *SdJobHandler) Image(c *gin.Context) { TaskInfo: utils.JsonEncode(task), Prompt: data.Prompt, Progress: 0, - Power: h.App.SysConfig.SdPower, + Power: h.App.SysConfig.Base.SdPower, CreatedAt: time.Now(), } res := h.DB.Create(&job) diff --git a/api/handler/sms_handler.go b/api/handler/sms_handler.go index 55332b96..5acc7216 100644 --- a/api/handler/sms_handler.go +++ b/api/handler/sms_handler.go @@ -25,7 +25,7 @@ const CodeStorePrefix = "/verify/codes/" type SmsHandler struct { BaseHandler redis *redis.Client - sms *sms.ServiceManager + sms *sms.SmsManager smtp *service.SmtpService captcha *service.CaptchaService } @@ -33,7 +33,7 @@ type SmsHandler struct { func NewSmsHandler( app *core.AppServer, client *redis.Client, - sms *sms.ServiceManager, + sms *sms.SmsManager, smtp *service.SmtpService, captcha *service.CaptchaService) *SmsHandler { return &SmsHandler{ @@ -62,7 +62,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - if h.App.SysConfig.EnabledVerify { + if h.App.SysConfig.Base.EnabledVerify { var check bool if data.X != 0 { check = h.captcha.SlideCheck(data) @@ -78,14 +78,14 @@ func (h *SmsHandler) SendCode(c *gin.Context) { code := utils.RandomNumber(6) var err error if strings.Contains(data.Receiver, "@") { // email - if !utils.Contains(h.App.SysConfig.RegisterWays, "email") { + if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "email") { resp.ERROR(c, "系统已禁用邮箱注册!") return } // 检查邮箱后缀是否在白名单 - if len(h.App.SysConfig.EmailWhiteList) > 0 { + if len(h.App.SysConfig.Base.EmailWhiteList) > 0 { inWhiteList := false - for _, suffix := range h.App.SysConfig.EmailWhiteList { + for _, suffix := range h.App.SysConfig.Base.EmailWhiteList { if strings.HasSuffix(data.Receiver, suffix) { inWhiteList = true break @@ -98,7 +98,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) { } err = h.smtp.SendVerifyCode(data.Receiver, code) } else { - if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") { + if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "mobile") { resp.ERROR(c, "系统已禁用手机号注册!") return } diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index f8ada5c1..4579bc62 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -82,7 +82,7 @@ func (h *SunoHandler) Create(c *gin.Context) { return } - if user.Power < h.App.SysConfig.SunoPower { + if user.Power < h.App.SysConfig.Base.SunoPower { resp.ERROR(c, "您的算力不足,请充值后再试!") return } @@ -130,7 +130,7 @@ func (h *SunoHandler) Create(c *gin.Context) { RefSongId: data.RefSongId, RefTaskId: data.RefTaskId, ExtendSecs: data.ExtendSecs, - Power: h.App.SysConfig.SunoPower, + Power: h.App.SysConfig.Base.SunoPower, SongId: utils.RandString(32), } if data.Lyrics != "" { diff --git a/api/handler/test_handler.go b/api/handler/test_handler.go index f2b6b54c..3a2508eb 100644 --- a/api/handler/test_handler.go +++ b/api/handler/test_handler.go @@ -4,19 +4,20 @@ import ( "geekai/core" "geekai/service" "geekai/service/payment" + "net/http" + "github.com/gin-gonic/gin" "gorm.io/gorm" - "net/http" ) type TestHandler struct { App *core.AppServer db *gorm.DB snowflake *service.Snowflake - js *payment.GeekPayService + js *payment.EPayService } -func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler { +func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.EPayService) *TestHandler { return &TestHandler{App: app, db: db, snowflake: snowflake, js: js} } diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 31c85c07..ec83eddb 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -20,8 +20,6 @@ import ( "strings" "time" - "github.com/imroc/req/v3" - "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" @@ -98,7 +96,7 @@ func (h *UserHandler) Register(c *gin.Context) { return } - if h.App.SysConfig.EnabledVerify && data.RegWay == "username" { + if h.App.SysConfig.Base.EnabledVerify && data.RegWay == "username" { var check bool if data.X != 0 { check = h.captcha.SlideCheck(data) @@ -163,7 +161,7 @@ func (h *UserHandler) Register(c *gin.Context) { ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色 ChatConfig: "{}", ChatModels: "{}", - Power: h.App.SysConfig.InitPower, + Power: h.App.SysConfig.Base.InitPower, } // check if the username is existing @@ -188,13 +186,13 @@ func (h *UserHandler) Register(c *gin.Context) { // 被邀请人也获得赠送算力 if data.InviteCode != "" { - user.Power += h.App.SysConfig.InvitePower + user.Power += h.App.SysConfig.Base.InvitePower } if h.licenseService.GetLicense().Configs.DeCopy { user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6)) } else { - defaultNickname := h.App.SysConfig.DefaultNickname + defaultNickname := h.App.SysConfig.Base.DefaultNickname if defaultNickname == "" { defaultNickname = "极客学长" } @@ -211,11 +209,11 @@ func (h *UserHandler) Register(c *gin.Context) { if data.InviteCode != "" { // 增加邀请数量 h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) - if h.App.SysConfig.InvitePower > 0 { - err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.InvitePower, model.PowerLog{ + if h.App.SysConfig.Base.InvitePower > 0 { + err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.Base.InvitePower, model.PowerLog{ Type: types.PowerInvite, Model: "Invite", - Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username), + Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.Base.InvitePower, inviteCode.Code, user.Username), }) if err != nil { tx.Rollback() @@ -230,7 +228,7 @@ func (h *UserHandler) Register(c *gin.Context) { UserId: user.Id, Username: user.Username, InviteCode: inviteCode.Code, - Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower), + Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.Base.InvitePower), }).Error if err != nil { tx.Rollback() @@ -276,7 +274,7 @@ func (h *UserHandler) Login(c *gin.Context) { verifyKey := fmt.Sprintf("users/verify/%s", data.Username) needVerify, err := h.redis.Get(c, verifyKey).Bool() - if h.App.SysConfig.EnabledVerify && needVerify { + if h.App.SysConfig.Base.EnabledVerify && needVerify { var check bool if data.X != 0 { check = h.captcha.SlideCheck(data) @@ -353,150 +351,128 @@ func (h *UserHandler) Logout(c *gin.Context) { // CLogin 第三方登录请求二维码 func (h *UserHandler) CLogin(c *gin.Context) { - returnURL := h.GetTrim(c, "return_url") - var res types.BizVo - apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL) - r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}). - SetHeader("AppId", h.App.Config.ApiConfig.AppId). - SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)). - SetSuccessResult(&res). - Post(apiURL) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - if r.IsErrorState() { - resp.ERROR(c, "error with login http status: "+r.Status) - return - } - if res.Code != types.Success { - resp.ERROR(c, "error with http response: "+res.Message) - return - } - - resp.SUCCESS(c, res.Data) } // CLoginCallback 第三方登录回调 func (h *UserHandler) CLoginCallback(c *gin.Context) { - loginType := c.Query("login_type") - code := c.Query("code") - userId := h.GetInt(c, "user_id", 0) - action := c.Query("action") + // loginType := c.Query("login_type") + // code := c.Query("code") + // userId := h.GetInt(c, "user_id", 0) + // action := c.Query("action") - var res types.BizVo - apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL) - r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}). - SetHeader("AppId", h.App.Config.ApiConfig.AppId). - SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)). - SetSuccessResult(&res). - Post(apiURL) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - if r.IsErrorState() { - resp.ERROR(c, "error with login http status: "+r.Status) - return - } + // var res types.BizVo + // apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL) + // r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}). + // SetHeader("AppId", h.App.Config.ApiConfig.AppId). + // SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)). + // SetSuccessResult(&res). + // Post(apiURL) + // if err != nil { + // resp.ERROR(c, err.Error()) + // return + // } + // if r.IsErrorState() { + // resp.ERROR(c, "error with login http status: "+r.Status) + // return + // } - if res.Code != types.Success { - resp.ERROR(c, "error with http response: "+res.Message) - return - } + // if res.Code != types.Success { + // resp.ERROR(c, "error with http response: "+res.Message) + // return + // } - // login successfully - data := res.Data.(map[string]interface{}) - var user model.User - if action == "bind" && userId > 0 { - err = h.DB.Where("openid", data["openid"]).First(&user).Error - if err == nil { - resp.ERROR(c, "该微信已经绑定其他账号,请先解绑") - return - } + // // login successfully + // data := res.Data.(map[string]interface{}) + // var user model.User + // if action == "bind" && userId > 0 { + // err = h.DB.Where("openid", data["openid"]).First(&user).Error + // if err == nil { + // resp.ERROR(c, "该微信已经绑定其他账号,请先解绑") + // return + // } - err = h.DB.Where("id", userId).First(&user).Error - if err != nil { - resp.ERROR(c, "绑定用户不存在") - return - } + // err = h.DB.Where("id", userId).First(&user).Error + // if err != nil { + // resp.ERROR(c, "绑定用户不存在") + // return + // } - err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error - if err != nil { - resp.ERROR(c, "更新用户信息失败,"+err.Error()) - return - } + // err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error + // if err != nil { + // resp.ERROR(c, "更新用户信息失败,"+err.Error()) + // return + // } - resp.SUCCESS(c, gin.H{"token": ""}) - return - } + // resp.SUCCESS(c, gin.H{"token": ""}) + // return + // } - session := gin.H{} - tx := h.DB.Where("openid", data["openid"]).First(&user) - if tx.Error != nil { - // create new user - var totalUser int64 - h.DB.Model(&model.User{}).Count(&totalUser) - if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { - resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") - return - } + // session := gin.H{} + // tx := h.DB.Where("openid", data["openid"]).First(&user) + // if tx.Error != nil { + // // create new user + // var totalUser int64 + // h.DB.Model(&model.User{}).Count(&totalUser) + // if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { + // resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") + // return + // } - salt := utils.RandString(8) - password := fmt.Sprintf("%d", utils.RandomNumber(8)) - user = model.User{ - Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)), - Password: utils.GenPassword(password, salt), - Avatar: fmt.Sprintf("%s", data["avatar"]), - Salt: salt, - Status: true, - ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色 - Power: h.App.SysConfig.InitPower, - OpenId: fmt.Sprintf("%s", data["openid"]), - Nickname: fmt.Sprintf("%s", data["nickname"]), - } + // salt := utils.RandString(8) + // password := fmt.Sprintf("%d", utils.RandomNumber(8)) + // user = model.User{ + // Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)), + // Password: utils.GenPassword(password, salt), + // Avatar: fmt.Sprintf("%s", data["avatar"]), + // Salt: salt, + // Status: true, + // ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色 + // Power: h.App.SysConfig.InitPower, + // OpenId: fmt.Sprintf("%s", data["openid"]), + // Nickname: fmt.Sprintf("%s", data["nickname"]), + // } - tx = h.DB.Create(&user) - if tx.Error != nil { - resp.ERROR(c, "保存数据失败") - logger.Error(tx.Error) - return - } - session["username"] = user.Username - session["password"] = password - } else { // login directly - // 更新最后登录时间和IP - user.LastLoginIp = c.ClientIP() - user.LastLoginAt = time.Now().Unix() - h.DB.Model(&user).Updates(user) + // tx = h.DB.Create(&user) + // if tx.Error != nil { + // resp.ERROR(c, "保存数据失败") + // logger.Error(tx.Error) + // return + // } + // session["username"] = user.Username + // session["password"] = password + // } else { // login directly + // // 更新最后登录时间和IP + // user.LastLoginIp = c.ClientIP() + // user.LastLoginAt = time.Now().Unix() + // h.DB.Model(&user).Updates(user) - h.DB.Create(&model.UserLoginLog{ - UserId: user.Id, - Username: user.Username, - LoginIp: c.ClientIP(), - LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()), - }) - } + // h.DB.Create(&model.UserLoginLog{ + // UserId: user.Id, + // Username: user.Username, + // LoginIp: c.ClientIP(), + // LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()), + // }) + // } - // 创建 token - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "user_id": user.Id, - "expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), - }) - tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) - if err != nil { - resp.ERROR(c, "Failed to generate token, "+err.Error()) - return - } - // 保存到 redis - key := fmt.Sprintf("users/%d", user.Id) - if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { - resp.ERROR(c, "error with save token: "+err.Error()) - return - } - session["token"] = tokenString - resp.SUCCESS(c, session) + // // 创建 token + // token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + // "user_id": user.Id, + // "expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), + // }) + // tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) + // if err != nil { + // resp.ERROR(c, "Failed to generate token, "+err.Error()) + // return + // } + // // 保存到 redis + // key := fmt.Sprintf("users/%d", user.Id) + // if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { + // resp.ERROR(c, "error with save token: "+err.Error()) + // return + // } + // session["token"] = tokenString + // resp.SUCCESS(c, session) } // Session 获取/验证会话 @@ -760,11 +736,11 @@ func (h *UserHandler) SignIn(c *gin.Context) { // 签到 h.levelDB.Put(key, true) - if h.App.SysConfig.DailyPower > 0 { - h.userService.IncreasePower(userId, h.App.SysConfig.DailyPower, model.PowerLog{ + if h.App.SysConfig.Base.DailyPower > 0 { + h.userService.IncreasePower(userId, h.App.SysConfig.Base.DailyPower, model.PowerLog{ Type: types.PowerSignIn, Model: "SignIn", - Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.DailyPower), + Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.Base.DailyPower), }) } resp.SUCCESS(c) diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index 607bec1a..8d8df750 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -78,7 +78,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { return } - if user.Power < h.App.SysConfig.LumaPower { + if user.Power < h.App.SysConfig.Base.LumaPower { resp.ERROR(c, "您的算力不足,请充值后再试!") return } @@ -95,14 +95,14 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { Type: types.VideoLuma, Prompt: data.Prompt, Params: params, - TranslateModelId: h.App.SysConfig.AssistantModelId, + TranslateModelId: h.App.SysConfig.Base.AssistantModelId, } // 插入数据库 job := model.VideoJob{ UserId: uint(userId), Type: types.VideoLuma, Prompt: data.Prompt, - Power: h.App.SysConfig.LumaPower, + Power: h.App.SysConfig.Base.LumaPower, TaskInfo: utils.JsonEncode(task), } tx := h.DB.Create(&job) @@ -157,7 +157,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) { // 计算当前任务所需算力 key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration) - power := h.App.SysConfig.KeLingPowers[key] + power := h.App.SysConfig.Base.KeLingPowers[key] if power == 0 { resp.ERROR(c, "当前模型暂不支持") return @@ -191,7 +191,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) { Type: types.VideoKeLing, Prompt: data.Prompt, Params: params, - TranslateModelId: h.App.SysConfig.AssistantModelId, + TranslateModelId: h.App.SysConfig.Base.AssistantModelId, Channel: data.Channel, } // 插入数据库 diff --git a/api/main.go b/api/main.go index 35305a39..f042cee1 100644 --- a/api/main.go +++ b/api/main.go @@ -30,7 +30,7 @@ import ( "log" "os" "os/signal" - "strconv" + "runtime/debug" "syscall" "time" @@ -71,15 +71,16 @@ func main() { if configFile == "" { configFile = "config.toml" } - debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG")) logger.Info("Loading config file: ", configFile) - if !debug { - defer func() { - if err := recover(); err != nil { - logger.Error("Panic Error:", err) + defer func() { + if err := recover(); err != nil { + logger.Error("Panic Error:", err) + // 打印堆栈信息 + if os.Getenv("GEEKAI_DEBUG") == "true" { + debug.PrintStack() } - }() - } + } + }() app := fx.New( // 初始化配置应用配置 @@ -89,16 +90,16 @@ func main() { log.Fatal(err) } config.Path = configFile - if debug { - _ = core.SaveConfig(config) - } return config }), // 创建应用服务 fx.Provide(core.NewServer), // 初始化 fx.Invoke(func(s *core.AppServer, client *redis.Client) { - s.Init(debug, client) + s.Init(client) + }), + fx.Provide(func(db *gorm.DB) *types.SystemConfig { + return core.LoadSystemConfig(db) }), // 初始化数据库 @@ -111,6 +112,12 @@ func main() { return xdbFS }), + // 数据修复 + fx.Provide(service.NewDataFixService), + fx.Invoke(func(s *core.AppServer, dfs *service.DataFixService) { + dfs.FixData() + }), + // 创建 Ip2Region 查询对象 fx.Provide(func() (*xdb.Searcher, error) { file, err := xdbFS.Open("res/ip2region.xdb") @@ -215,20 +222,34 @@ func main() { fx.Invoke(func(service *jimeng.Service) { service.Start() }), - fx.Provide(service.NewUserService), - fx.Provide(payment.NewAlipayService), - fx.Provide(payment.NewHuPiPay), - fx.Provide(payment.NewJPayService), - fx.Provide(payment.NewWechatService), fx.Provide(service.NewSnowflake), - // 创建服务 - fx.Provide(sms.NewSendServiceManager), - fx.Provide(func(config *types.AppConfig) *service.CaptchaService { - return service.NewCaptchaService(config.ApiConfig) + // 创建短信服务 + fx.Provide(sms.NewAliYunSmsService), + fx.Provide(sms.NewBaoSmsService), + fx.Provide(sms.NewSmsManager), + fx.Provide(func(config *types.SystemConfig) *service.CaptchaService { + return service.NewCaptchaService(config.GeekAPI.Captcha) }), + fx.Provide(func(config *types.SystemConfig, client *redis.Client) *service.WxLoginService { + return service.NewWxLoginService(config.GeekAPI.WxLogin, client) + }), + + // 支付服务 + fx.Provide(payment.NewAlipayService), + fx.Provide(payment.NewEPayService), + fx.Provide(payment.NewWxpayService), + + // 文件上传服务 + fx.Provide(oss.NewLocalStorage), + fx.Provide(oss.NewMiniOss), + fx.Provide(oss.NewQiNiuOss), + fx.Provide(oss.NewAliYunOss), fx.Provide(oss.NewUploaderManager), + // 用户服务 + fx.Provide(service.NewUserService), + // 注册路由 fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { h.RegisterRoutes() diff --git a/api/service/captcha_service.go b/api/service/captcha_service.go index 01ec929b..b2e76d45 100644 --- a/api/service/captcha_service.go +++ b/api/service/captcha_service.go @@ -20,9 +20,9 @@ type CaptchaService struct { client *req.Client } -func NewCaptchaService(config types.CaptchaConfig) *CaptchaService { +func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService { return &CaptchaService{ - config: config, + config: captchaConfig, client: req.C().SetTimeout(10 * time.Second), } } diff --git a/api/service/config_migration.go b/api/service/config_migration.go index eff6760a..606741b2 100644 --- a/api/service/config_migration.go +++ b/api/service/config_migration.go @@ -65,12 +65,6 @@ func (s *ConfigMigrationService) MigrateFromConfig(config *types.AppConfig) erro return err } - // 迁移API配置 - if err := s.migrateApiConfig(config); err != nil { - logger.Errorf("迁移API配置失败: %v", err) - return err - } - // 标记迁移完成 if err := s.markMigrationCompleted(); err != nil { logger.Errorf("标记迁移完成失败: %v", err) @@ -101,58 +95,13 @@ func (s *ConfigMigrationService) markMigrationCompleted() error { // 迁移支付配置 func (s *ConfigMigrationService) migratePaymentConfig(config *types.AppConfig) error { - // 支付宝配置 - alipayConfig := map[string]any{ - "enabled": config.AlipayConfig.Enabled, - "sand_box": config.AlipayConfig.SandBox, - "app_id": config.AlipayConfig.AppId, - "private_key": config.AlipayConfig.PrivateKey, - "alipay_public_key": config.AlipayConfig.AlipayPublicKey, - "notify_url": config.AlipayConfig.NotifyURL, - "return_url": config.AlipayConfig.ReturnURL, - } - if err := s.saveConfig("alipay", alipayConfig); err != nil { - return err - } - // 微信支付配置 - wechatConfig := map[string]any{ - "enabled": config.WechatPayConfig.Enabled, - "app_id": config.WechatPayConfig.AppId, - "mch_id": config.WechatPayConfig.MchId, - "serial_no": config.WechatPayConfig.SerialNo, - "private_key": config.WechatPayConfig.PrivateKey, - "api_v3_key": config.WechatPayConfig.ApiV3Key, - "notify_url": config.WechatPayConfig.NotifyURL, + paymentConfig := types.PaymentConfig{ + Alipay: config.AlipayConfig, + Epay: config.GeekPayConfig, + WxPay: config.WechatPayConfig, } - if err := s.saveConfig("wechat", wechatConfig); err != nil { - return err - } - - // 虎皮椒配置 - hupiConfig := map[string]any{ - "enabled": config.HuPiPayConfig.Enabled, - "app_id": config.HuPiPayConfig.AppId, - "app_secret": config.HuPiPayConfig.AppSecret, - "api_url": config.HuPiPayConfig.ApiURL, - "notify_url": config.HuPiPayConfig.NotifyURL, - "return_url": config.HuPiPayConfig.ReturnURL, - } - if err := s.saveConfig("hupi", hupiConfig); err != nil { - return err - } - - // GeekPay配置 - geekpayConfig := map[string]any{ - "enabled": config.GeekPayConfig.Enabled, - "app_id": config.GeekPayConfig.AppId, - "private_key": config.GeekPayConfig.PrivateKey, - "api_url": config.GeekPayConfig.ApiURL, - "notify_url": config.GeekPayConfig.NotifyURL, - "return_url": config.GeekPayConfig.ReturnURL, - "methods": config.GeekPayConfig.Methods, - } - if err := s.saveConfig("geekpay", geekpayConfig); err != nil { + if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil { return err } @@ -161,37 +110,15 @@ func (s *ConfigMigrationService) migratePaymentConfig(config *types.AppConfig) e // 迁移存储配置 func (s *ConfigMigrationService) migrateStorageConfig(config *types.AppConfig) error { - ossConfig := map[string]any{ - "active": config.OSS.Active, - "local": map[string]any{ - "base_path": config.OSS.Local.BasePath, - "base_url": config.OSS.Local.BaseURL, - }, - "minio": map[string]any{ - "endpoint": config.OSS.Minio.Endpoint, - "access_key": config.OSS.Minio.AccessKey, - "access_secret": config.OSS.Minio.AccessSecret, - "bucket": config.OSS.Minio.Bucket, - "use_ssl": config.OSS.Minio.UseSSL, - "domain": config.OSS.Minio.Domain, - }, - "qiniu": map[string]any{ - "zone": config.OSS.QiNiu.Zone, - "access_key": config.OSS.QiNiu.AccessKey, - "access_secret": config.OSS.QiNiu.AccessSecret, - "bucket": config.OSS.QiNiu.Bucket, - "domain": config.OSS.QiNiu.Domain, - }, - "aliyun": map[string]any{ - "endpoint": config.OSS.AliYun.Endpoint, - "access_key": config.OSS.AliYun.AccessKey, - "access_secret": config.OSS.AliYun.AccessSecret, - "bucket": config.OSS.AliYun.Bucket, - "sub_dir": config.OSS.AliYun.SubDir, - "domain": config.OSS.AliYun.Domain, - }, + + ossConfig := types.OSSConfig{ + Active: config.OSS.Active, + Local: config.OSS.Local, + Minio: config.OSS.Minio, + QiNiu: config.OSS.QiNiu, + AliYun: config.OSS.AliYun, } - return s.saveConfig("oss", ossConfig) + return s.saveConfig(types.ConfigKeyOss, ossConfig) } // 迁移通信配置 @@ -205,7 +132,7 @@ func (s *ConfigMigrationService) migrateCommunicationConfig(config *types.AppCon "from": config.SmtpConfig.From, "password": config.SmtpConfig.Password, } - if err := s.saveConfig("smtp", smtpConfig); err != nil { + if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil { return err } @@ -215,34 +142,17 @@ func (s *ConfigMigrationService) migrateCommunicationConfig(config *types.AppCon "ali": map[string]any{ "access_key": config.SMS.Ali.AccessKey, "access_secret": config.SMS.Ali.AccessSecret, - "product": config.SMS.Ali.Product, - "domain": config.SMS.Ali.Domain, "sign": config.SMS.Ali.Sign, "code_temp_id": config.SMS.Ali.CodeTempId, }, "bao": map[string]any{ "username": config.SMS.Bao.Username, "password": config.SMS.Bao.Password, - "domain": config.SMS.Bao.Domain, "sign": config.SMS.Bao.Sign, "code_template": config.SMS.Bao.CodeTemplate, }, } - return s.saveConfig("sms", smsConfig) -} - -// 迁移API配置 -func (s *ConfigMigrationService) migrateApiConfig(config *types.AppConfig) error { - apiConfig := map[string]any{ - "api_url": config.ApiConfig.ApiURL, - "app_id": config.ApiConfig.AppId, - "token": config.ApiConfig.Token, - "jimeng_config": map[string]any{ - "access_key": config.ApiConfig.JimengConfig.AccessKey, - "secret_key": config.ApiConfig.JimengConfig.SecretKey, - }, - } - return s.saveConfig("api", apiConfig) + return s.saveConfig(types.ConfigKeySms, smsConfig) } // 保存配置到数据库 diff --git a/api/service/data_fix_service.go b/api/service/data_fix_service.go new file mode 100644 index 00000000..860da884 --- /dev/null +++ b/api/service/data_fix_service.go @@ -0,0 +1,66 @@ +package service + +import ( + "geekai/store/model" + + "github.com/go-redis/redis/v8" + "gorm.io/gorm" +) + +type DataFixService struct { + db *gorm.DB + redis *redis.Client +} + +func NewDataFixService(db *gorm.DB, redis *redis.Client) *DataFixService { + return &DataFixService{db: db, redis: redis} +} + +func (s *DataFixService) FixData() { + s.FixColumn() +} + +// 字段修正 +func (s *DataFixService) FixColumn() { + // 订单字段整理 + if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") { + s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel") + } + if !s.db.Migrator().HasColumn(&model.Order{}, "check") { + s.db.Migrator().AddColumn(&model.Order{}, "checked") + } + + // 重命名 config 表字段 + if s.db.Migrator().HasColumn(&model.Config{}, "config_json") { + s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value") + } + if s.db.Migrator().HasColumn(&model.Config{}, "marker") { + s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name") + } + if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") { + s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key") + } + if s.db.Migrator().HasIndex(&model.Config{}, "marker") { + s.db.Migrator().DropIndex(&model.Config{}, "marker") + } + + // 手动删除字段 + if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") { + s.db.Migrator().DropColumn(&model.Order{}, "deleted_at") + } + if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") { + s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at") + } + if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") { + s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at") + } + if s.db.Migrator().HasColumn(&model.User{}, "chat_config") { + s.db.Migrator().DropColumn(&model.User{}, "chat_config") + } + if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") { + s.db.Migrator().DropColumn(&model.ChatModel{}, "category") + } + if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") { + s.db.Migrator().DropColumn(&model.ChatModel{}, "description") + } +} diff --git a/api/service/license_service.go b/api/service/license_service.go index 1dd85c1d..ea104a55 100644 --- a/api/service/license_service.go +++ b/api/service/license_service.go @@ -21,7 +21,6 @@ import ( ) type LicenseService struct { - config types.GeekServiceConfig levelDB *store.LevelDB license *types.License urlWhiteList []string @@ -39,7 +38,6 @@ func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseS } logger.Infof("License: %+v", license) return &LicenseService{ - config: server.Config.ApiConfig, levelDB: levelDB, license: &license, machineId: machineId, @@ -63,7 +61,7 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error { Message string `json:"message"` Data License `json:"data"` } - apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active") + apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active") response, err := req.C().R(). SetBody(map[string]string{"license": license, "machine_id": machineId}). SetSuccessResult(&res).Post(apiURL) @@ -129,7 +127,7 @@ func (s *LicenseService) fetchLicense() (*types.License, error) { Message string `json:"message"` Data License `json:"data"` } - apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check") + apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check") response, err := req.C().R(). SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}). SetSuccessResult(&res).Post(apiURL) @@ -158,7 +156,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) { Message string `json:"message"` Data []string `json:"data"` } - apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls") + apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls") response, err := req.C().R().SetSuccessResult(&res).Get(apiURL) if err != nil { return nil, fmt.Errorf("发送请求失败: %v", err) diff --git a/api/service/oss/aliyun_oss.go b/api/service/oss/aliyun_oss.go index 271cdfff..827c85d3 100644 --- a/api/service/oss/aliyun_oss.go +++ b/api/service/oss/aliyun_oss.go @@ -28,30 +28,32 @@ type AliYunOss struct { proxyURL string } -func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) { - config := &appConfig.OSS.AliYun - // 创建 OSS 客户端 +func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) { + s := &AliYunOss{ + proxyURL: appConfig.ProxyURL, + } + if sysConfig.OSS.Active == AliYun { + err := s.UpdateConfig(&sysConfig.OSS.AliYun) + if err != nil { + logger.Errorf("阿里云OSS初始化失败: %v", err) + } + } + return s, nil + +} + +func (s *AliYunOss) UpdateConfig(config *types.AliYunOssConfig) error { client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret) if err != nil { - return nil, err + return err } - - // 获取存储空间 bucket, err := client.Bucket(config.Bucket) if err != nil { - return nil, err + return err } - - if config.SubDir == "" { - config.SubDir = "gpt" - } - - return &AliYunOss{ - config: config, - bucket: bucket, - proxyURL: appConfig.ProxyURL, - }, nil - + s.bucket = bucket + s.config = config + return nil } func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) { diff --git a/api/service/oss/localstorage.go b/api/service/oss/localstorage.go index 37d4f5ff..b2337e8c 100644 --- a/api/service/oss/localstorage.go +++ b/api/service/oss/localstorage.go @@ -25,13 +25,17 @@ type LocalStorage struct { proxyURL string } -func NewLocalStorage(config *types.AppConfig) LocalStorage { - return LocalStorage{ - config: &config.OSS.Local, - proxyURL: config.ProxyURL, +func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage { + return &LocalStorage{ + config: &sysConfig.OSS.Local, + proxyURL: appConfig.ProxyURL, } } +func (s *LocalStorage) UpdateConfig(config *types.LocalStorageConfig) { + s.config = config +} + func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) { file, err := ctx.FormFile(name) if err != nil { diff --git a/api/service/oss/minio_oss.go b/api/service/oss/minio_oss.go index 530dd0e0..d7f28d4b 100644 --- a/api/service/oss/minio_oss.go +++ b/api/service/oss/minio_oss.go @@ -29,19 +29,29 @@ type MiniOss struct { proxyURL string } -func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) { - config := &appConfig.OSS.Minio +func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) { + + s := &MiniOss{proxyURL: appConfig.ProxyURL} + if sysConfig.OSS.Active == Minio { + err := s.UpdateConfig(&sysConfig.OSS.Minio) + if err != nil { + logger.Errorf("MinioOSS初始化失败: %v", err) + } + } + return s, nil +} + +func (s *MiniOss) UpdateConfig(config *types.MiniOssConfig) error { minioClient, err := minio.New(config.Endpoint, &minio.Options{ Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""), Secure: config.UseSSL, }) if err != nil { - return MiniOss{}, err + return err } - if config.SubDir == "" { - config.SubDir = "gpt" - } - return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil + s.config = config + s.client = minioClient + return nil } func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) { diff --git a/api/service/oss/qiniu_oss.go b/api/service/oss/qiniu_oss.go index 3913410e..aec18918 100644 --- a/api/service/oss/qiniu_oss.go +++ b/api/service/oss/qiniu_oss.go @@ -29,13 +29,21 @@ type QinNiuOss struct { mac *qbox.Mac putPolicy storage.PutPolicy uploader *storage.FormUploader - manager *storage.BucketManager + bucket *storage.BucketManager proxyURL string } -func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss { - config := &appConfig.OSS.QiNiu - // build storage uploader +func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QinNiuOss { + s := &QinNiuOss{ + proxyURL: appConfig.ProxyURL, + } + if sysConfig.OSS.Active == QiNiu { + s.UpdateConfig(&sysConfig.OSS.QiNiu) + } + return s +} + +func (s *QinNiuOss) UpdateConfig(config *types.QiNiuOssConfig) { zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone)) if !ok { zone = storage.ZoneHuanan @@ -47,19 +55,12 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss { putPolicy := storage.PutPolicy{ Scope: config.Bucket, } - if config.SubDir == "" { - config.SubDir = "gpt" - } - return QinNiuOss{ - config: config, - mac: mac, - putPolicy: putPolicy, - uploader: formUploader, - manager: storage.NewBucketManager(mac, &storeConfig), - proxyURL: appConfig.ProxyURL, - } + s.config = config + s.mac = mac + s.putPolicy = putPolicy + s.uploader = formUploader + s.bucket = storage.NewBucketManager(mac, &storeConfig) } - func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) { // 解析表单 file, err := ctx.FormFile(name) @@ -147,7 +148,7 @@ func (s QinNiuOss) Delete(fileURL string) error { objectKey = fileURL } - return s.manager.Delete(s.config.Bucket, objectKey) + return s.bucket.Delete(s.config.Bucket, objectKey) } var _ Uploader = QinNiuOss{} diff --git a/api/service/oss/uploader_manager.go b/api/service/oss/uploader_manager.go index 573891b5..be011aff 100644 --- a/api/service/oss/uploader_manager.go +++ b/api/service/oss/uploader_manager.go @@ -10,44 +10,45 @@ package oss import ( "geekai/core/types" "strings" + + logger2 "geekai/logger" ) +var logger = logger2.GetLogger() + type UploaderManager struct { - handler Uploader + local *LocalStorage + aliyun *AliYunOss + mini *MiniOss + qiniu *QinNiuOss + config *types.OSSConfig } -func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) { - active := Local - if config.OSS.Active != "" { - active = strings.ToUpper(config.OSS.Active) - } - var handler Uploader - switch active { - case Local: - handler = NewLocalStorage(config) - break - case Minio: - client, err := NewMiniOss(config) - if err != nil { - return nil, err - } - handler = client - break - case QiNiu: - handler = NewQiNiuOss(config) - break - case AliYun: - client, err := NewAliYunOss(config) - if err != nil { - return nil, err - } - handler = client - break +func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QinNiuOss) (*UploaderManager, error) { + if sysConfig.OSS.Active == "" { + sysConfig.OSS.Active = Local } + sysConfig.OSS.Active = strings.ToLower(sysConfig.OSS.Active) - return &UploaderManager{handler: handler}, nil + return &UploaderManager{ + config: &sysConfig.OSS, + local: local, + aliyun: aliyun, + mini: mini, + qiniu: qiniu, + }, nil } func (m *UploaderManager) GetUploadHandler() Uploader { - return m.handler + switch m.config.Active { + case Local: + return m.local + case AliYun: + return m.aliyun + case Minio: + return m.mini + case QiNiu: + return m.qiniu + } + return m.local } diff --git a/api/service/payment/alipay_service.go b/api/service/payment/alipay_service.go index b2c10e8c..ce567f2a 100644 --- a/api/service/payment/alipay_service.go +++ b/api/service/payment/alipay_service.go @@ -20,109 +20,90 @@ import ( ) type AlipayService struct { - config *types.AlipayConfig client *alipay.Client + config *types.AlipayConfig } var logger = logger2.GetLogger() -func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) { - config := appConfig.AlipayConfig +func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) { + config := sysConfig.Payment.Alipay if !config.Enabled { - logger.Info("Disabled Alipay service") - return nil, nil + logger.Debug("Disabled Alipay service") } + service := &AlipayService{config: &config} + if config.Enabled { + err := service.UpdateConfig(&config) + if err != nil { + logger.Errorf("支付宝服务初始化失败: %v", err) + } + } + + return service, nil +} + +func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error { client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox) if err != nil { - return nil, fmt.Errorf("error with initialize alipay service: %v", err) + return fmt.Errorf("error with initialize alipay service: %v", err) } - return &AlipayService{config: &config, client: client}, nil + s.client = client + s.config = config + if os.Getenv("GEEKAI_DEBUG") == "true" { + logger.Info("Alipay Debug mode is enabled") + client.DebugSwitch = gopay.DebugOn + } + return nil } -type AlipayParams struct { - OutTradeNo string `json:"out_trade_no"` - Subject string `json:"subject"` - TotalFee string `json:"total_fee"` - ReturnURL string `json:"return_url"` - NotifyURL string `json:"notify_url"` -} - -func (s *AlipayService) PayMobile(params AlipayParams) (string, error) { - bm := make(gopay.BodyMap) - bm.Set("subject", params.Subject) - bm.Set("out_trade_no", params.OutTradeNo) - bm.Set("quit_url", params.ReturnURL) - bm.Set("total_amount", params.TotalFee) - bm.Set("product_code", "QUICK_WAP_WAY") - return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm) -} - -func (s *AlipayService) PayPC(params AlipayParams) (string, error) { +func (s *AlipayService) Pay(params PayRequest) (string, error) { bm := make(gopay.BodyMap) bm.Set("subject", params.Subject) bm.Set("out_trade_no", params.OutTradeNo) bm.Set("total_amount", params.TotalFee) - bm.Set("product_code", "FAST_INSTANT_TRADE_PAY") - return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm) + return s.client.TradeWapPay(context.Background(), bm) +} + +func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) { + bm := make(gopay.BodyMap) + bm.Set("out_trade_no", outTradeNo) + rsp, err := s.client.TradeQuery(context.Background(), bm) + if err != nil { + return OrderInfo{}, fmt.Errorf("error with trade query: %v", err) + } + + switch rsp.Response.TradeStatus { + case "TRADE_SUCCESS": + logger.Debugf("支付宝查询订单成功:%+v", rsp.Response) + return OrderInfo{ + OutTradeNo: rsp.Response.OutTradeNo, + TradeId: rsp.Response.TradeNo, + Amount: rsp.Response.TotalAmount, + Status: Success, + PayTime: rsp.Response.SendPayDate, + }, nil + case "TRADE_CLOSED": + return OrderInfo{Status: Closed}, nil + default: + return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus) + } } // TradeVerify 交易验证 -func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo { +func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) { notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法 if err != nil { - return NotifyVo{ - Status: Failure, - Message: "error with parse notify request: " + err.Error(), - } + return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err) } _, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq) if err != nil { - return NotifyVo{ - Status: Failure, - Message: "error with verify sign: " + err.Error(), - } + return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err) } - return s.TradeQuery(request.Form.Get("out_trade_no")) + return s.Query(request.Form.Get("out_trade_no")) } -func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo { - bm := make(gopay.BodyMap) - bm.Set("out_trade_no", outTradeNo) - - //查询订单 - rsp, err := s.client.TradeQuery(context.Background(), bm) - if err != nil { - return NotifyVo{ - Status: Failure, - Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(), - } - } - - if rsp.Response.TradeStatus == "TRADE_SUCCESS" { - return NotifyVo{ - Status: Success, - OutTradeNo: rsp.Response.OutTradeNo, - TradeId: rsp.Response.TradeNo, - Amount: rsp.Response.TotalAmount, - Subject: rsp.Response.Subject, - Message: "OK", - } - } else { - return NotifyVo{ - Status: Failure, - Message: "异步查询验证订单信息发生错误" + outTradeNo, - } - } -} - -func readKey(filename string) (string, error) { - data, err := os.ReadFile(filename) - if err != nil { - return "", err - } - return string(data), nil -} +var _ PayService = (*AlipayService)(nil) diff --git a/api/service/payment/geekpay_service.go b/api/service/payment/epay_service.go similarity index 50% rename from api/service/payment/geekpay_service.go rename to api/service/payment/epay_service.go index c5306c6f..3f069323 100644 --- a/api/service/payment/geekpay_service.go +++ b/api/service/payment/epay_service.go @@ -22,41 +22,30 @@ import ( "time" ) -// GeekPayService Geek 支付服务 -type GeekPayService struct { - config *types.GeekPayConfig +// EPayService 易支付服务 +type EPayService struct { + config *types.EpayConfig } -func NewJPayService(appConfig *types.AppConfig) *GeekPayService { - return &GeekPayService{ - config: &appConfig.GeekPayConfig, +func NewEPayService(sysConfig *types.SystemConfig) *EPayService { + return &EPayService{ + config: &sysConfig.Payment.Epay, } } -type GeekPayParams struct { - Method string `json:"method"` // 接口类型 - Device string `json:"device"` // 设备类型 - Type string `json:"type"` // 支付方式 - OutTradeNo string `json:"out_trade_no"` // 商户订单号 - Name string `json:"name"` // 商品名称 - Money string `json:"money"` // 商品金额 - ClientIP string `json:"clientip"` //用户IP地址 - SubOpenId string `json:"sub_openid"` // 微信用户 openid,仅小程序支付需要 - SubAppId string `json:"sub_appid"` // 小程序 AppId,仅小程序支付需要 - NotifyURL string `json:"notify_url"` - ReturnURL string `json:"return_url"` +func (s *EPayService) UpdateConfig(config *types.EpayConfig) { + s.config = config } // Pay 支付订单 -func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) { +func (s *EPayService) Pay(params PayRequest) (string, error) { p := map[string]string{ - "pid": s.config.AppId, - //"method": params.Method, + "pid": s.config.AppId, "device": params.Device, - "type": params.Type, + "type": params.PayWay, "out_trade_no": params.OutTradeNo, - "name": params.Name, - "money": params.Money, + "name": params.Subject, + "money": params.TotalFee, "clientip": params.ClientIP, "notify_url": params.NotifyURL, "return_url": params.ReturnURL, @@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) { } p["sign"] = s.Sign(p) p["sign_type"] = "MD5" - return s.sendRequest(s.config.ApiURL, p) + resp, err := s.sendRequest(s.config.ApiURL, p) + if err != nil { + return "", err + } + if resp.Code != 1 { + return "", errors.New(resp.Msg) + } + if resp.PayURL != "" { + return resp.PayURL, nil + } else { + return resp.QrCode, nil + } } -func (s *GeekPayService) Sign(params map[string]string) string { +func (s *EPayService) Sign(params map[string]string) string { // 按字母顺序排序参数 var keys []string for k := range params { @@ -100,7 +100,7 @@ type GeekPayResp struct { UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接 } -func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) { +func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) { form := url.Values{} for k, v := range params { form.Add(k, v) @@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) } return &r, nil } + +func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) { + + params := url.Values{} + params.Set("act", "order") + params.Set("pid", s.config.AppId) + params.Set("key", s.config.PrivateKey) + params.Set("out_trade_no", outTradeNo) + + apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode()) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + resp, err := client.Get(apiURL) + if err != nil { + return OrderInfo{}, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return OrderInfo{}, err + } + logger.Debugf(string(body)) + + var result struct { + Code int `json:"code"` + Msg string `json:"msg"` + Status string `json:"status"` + Name string `json:"name"` + Money string `json:"money"` + EndTime string `json:"endtime"` + TradeNo string `json:"trade_no"` + } + if err := json.Unmarshal(body, &result); err != nil { + return OrderInfo{}, errors.New("订单查询响应解析失败") + } + if result.Code != 1 { + return OrderInfo{}, errors.New(result.Msg) + } + logger.Debugf("订单信息:%+v", result) + orderInfo := OrderInfo{ + OutTradeNo: outTradeNo, + TradeId: result.TradeNo, + Amount: result.Money, + PayTime: result.EndTime, + } + if result.Status == "1" { + orderInfo.Status = Success + } else { + orderInfo.Status = Failure + } + return orderInfo, nil +} + +var _ PayService = (*EPayService)(nil) diff --git a/api/service/payment/hupipay_serive.go b/api/service/payment/hupipay_serive.go deleted file mode 100644 index b7266b7e..00000000 --- a/api/service/payment/hupipay_serive.go +++ /dev/null @@ -1,171 +0,0 @@ -package payment - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "crypto/md5" - "encoding/hex" - "errors" - "fmt" - "geekai/core/types" - "geekai/utils" - "io" - "net/http" - "net/url" - "sort" - "strconv" - "strings" - "time" -) - -type HuPiPayService struct { - appId string - appSecret string - apiURL string -} - -func NewHuPiPay(config *types.AppConfig) *HuPiPayService { - return &HuPiPayService{ - appId: config.HuPiPayConfig.AppId, - appSecret: config.HuPiPayConfig.AppSecret, - apiURL: config.HuPiPayConfig.ApiURL, - } -} - -type HuPiPayParams struct { - AppId string `json:"appid"` - Version string `json:"version"` - TradeOrderId string `json:"trade_order_id"` - TotalFee string `json:"total_fee"` - Title string `json:"title"` - NotifyURL string `json:"notify_url"` - ReturnURL string `json:"return_url"` - WapName string `json:"wap_name"` - CallbackURL string `json:"callback_url"` - Time string `json:"time"` - NonceStr string `json:"nonce_str"` - Type string `json:"type"` - WapUrl string `json:"wap_url"` -} - -type HuPiPayResp struct { - Openid interface{} `json:"openid"` - UrlQrcode string `json:"url_qrcode"` - URL string `json:"url"` - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg,omitempty"` -} - -// Pay 执行支付请求操作 -func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) { - data := url.Values{} - simple := strconv.FormatInt(time.Now().Unix(), 10) - params.AppId = s.appId - params.Time = simple - params.NonceStr = simple - encode := utils.JsonEncode(params) - m := make(map[string]string) - _ = utils.JsonDecode(encode, &m) - for k, v := range m { - data.Add(k, fmt.Sprintf("%v", v)) - } - // 生成签名 - data.Add("hash", s.Sign(data)) - // 发送支付请求 - apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL) - resp, err := http.PostForm(apiURL, data) - if err != nil { - return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err) - } - defer resp.Body.Close() - all, err := io.ReadAll(resp.Body) - if err != nil { - return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err) - } - - var res HuPiPayResp - err = utils.JsonDecode(string(all), &res) - if err != nil { - return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err) - } - - if res.ErrCode != 0 { - return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg) - } - - return res, nil -} - -// Sign 签名方法 -func (s *HuPiPayService) Sign(params url.Values) string { - params.Del(`Sign`) - var keys = make([]string, 0, 0) - for key := range params { - if params.Get(key) != `` { - keys = append(keys, key) - } - } - sort.Strings(keys) - - var pList = make([]string, 0, 0) - for _, key := range keys { - var value = strings.TrimSpace(params.Get(key)) - if len(value) > 0 { - pList = append(pList, key+"="+value) - } - } - var src = strings.Join(pList, "&") - src += s.appSecret - - md5bs := md5.Sum([]byte(src)) - return hex.EncodeToString(md5bs[:]) -} - -// Check 校验订单状态 -func (s *HuPiPayService) Check(outTradeNo string) error { - data := url.Values{} - data.Add("appid", s.appId) - data.Add("out_trade_order", outTradeNo) - stamp := strconv.FormatInt(time.Now().Unix(), 10) - data.Add("time", stamp) - data.Add("nonce_str", stamp) - data.Add("hash", s.Sign(data)) - - apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL) - resp, err := http.PostForm(apiURL, data) - if err != nil { - return fmt.Errorf("error with http reqeust: %v", err) - } - - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("error with reading response: %v", err) - } - - var r struct { - ErrCode int `json:"errcode"` - Data struct { - Status string `json:"status"` - OpenOrderId string `json:"open_order_id"` - } `json:"data,omitempty"` - ErrMsg string `json:"errmsg"` - Hash string `json:"hash"` - } - err = utils.JsonDecode(string(body), &r) - if err != nil { - return fmt.Errorf("error with decode response: %v", err) - } - - if r.ErrCode == 0 && r.Data.Status == "OD" { - return nil - } else { - logger.Debugf("%+v", r) - return errors.New("order not paid:" + r.ErrMsg) - } -} diff --git a/api/service/payment/pay_service.go b/api/service/payment/pay_service.go new file mode 100644 index 00000000..06c65728 --- /dev/null +++ b/api/service/payment/pay_service.go @@ -0,0 +1,54 @@ +package payment + +// 支付渠道定义 +const PayChannelAL = "alipay" // 支付宝 +const PayChannelWX = "wxpay" // 微信支付 +const PayChannelEpay = "epay" // 易支付 + +// 支付方式 +const PayWayAL = "alipay" +const PayWayWX = "wxpay" + +const ( + Success = 0 + Failure = 1 + Closed = 2 +) + +type PayRequest struct { + OutTradeNo string // 商户订单号 + Subject string // 商品名称 + TotalFee string // 商品金额 + ReturnURL string // 回调地址 + NotifyURL string // 回调地址 + + // 易支付专有参数 + Method string // 接口类型 + Device string // 设备类型 + PayWay string // 支付方式 + ClientIP string //用户IP地址 + OpenID string // 用户openid + +} + +type OrderInfo struct { + Mchid string // 商户号 + OutTradeNo string // 商户订单号 + TradeId string // 交易号 + Amount string // 金额 + Status int // 状态 0: 未支付 1: 已支付 2: 已关闭 + PayTime string // 完成支付时间 +} + +func (o OrderInfo) Closed() bool { + return o.Status == Closed +} + +func (o OrderInfo) Success() bool { + return o.Status == Success +} + +type PayService interface { + Pay(params PayRequest) (string, error) // 生成支付链接 + Query(outTradeNo string) (OrderInfo, error) // 查询订单 +} diff --git a/api/service/payment/types.go b/api/service/payment/types.go deleted file mode 100644 index ef8ff24c..00000000 --- a/api/service/payment/types.go +++ /dev/null @@ -1,19 +0,0 @@ -package payment - -type NotifyVo struct { - Status int - OutTradeNo string // 商户订单号 - TradeId string // 交易ID - Amount string // 交易金额 - Message string - Subject string -} - -func (v NotifyVo) Success() bool { - return v.Status == Success -} - -const ( - Success = 0 - Failure = 1 -) diff --git a/api/service/payment/wepay_service.go b/api/service/payment/wepay_service.go deleted file mode 100644 index 7137718a..00000000 --- a/api/service/payment/wepay_service.go +++ /dev/null @@ -1,141 +0,0 @@ -package payment - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "context" - "fmt" - "geekai/core/types" - "net/http" - "time" - - "github.com/go-pay/gopay" - "github.com/go-pay/gopay/wechat/v3" -) - -type WechatPayService struct { - config *types.WechatPayConfig - client *wechat.ClientV3 -} - -func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) { - config := appConfig.WechatPayConfig - if !config.Enabled { - logger.Info("Disabled WechatPay service") - return nil, nil - } - - client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey) - if err != nil { - return nil, fmt.Errorf("error with initialize WechatPay service: %v", err) - } - err = client.AutoVerifySign() - if err != nil { - return nil, fmt.Errorf("error with autoVerifySign: %v", err) - } - //client.DebugSwitch = gopay.DebugOn - - return &WechatPayService{config: &config, client: client}, nil -} - -type WechatPayParams struct { - OutTradeNo string `json:"out_trade_no"` - TotalFee int `json:"total_fee"` - Subject string `json:"subject"` - ClientIP string `json:"client_ip"` - ReturnURL string `json:"return_url"` - NotifyURL string `json:"notify_url"` -} - -func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) { - expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) - // 初始化 BodyMap - bm := make(gopay.BodyMap) - bm.Set("appid", s.config.AppId). - Set("mchid", s.config.MchId). - Set("description", params.Subject). - Set("out_trade_no", params.OutTradeNo). - Set("time_expire", expire). - Set("notify_url", params.NotifyURL). - SetBodyMap("amount", func(bm gopay.BodyMap) { - bm.Set("total", params.TotalFee). - Set("currency", "CNY") - }) - - wxRsp, err := s.client.V3TransactionNative(context.Background(), bm) - if err != nil { - return "", fmt.Errorf("error with client v3 transaction Native: %v", err) - } - if wxRsp.Code != wechat.Success { - return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error) - } - return wxRsp.Response.CodeUrl, nil -} - -func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) { - expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) - // 初始化 BodyMap - bm := make(gopay.BodyMap) - bm.Set("appid", s.config.AppId). - Set("mchid", s.config.MchId). - Set("description", params.Subject). - Set("out_trade_no", params.OutTradeNo). - Set("time_expire", expire). - Set("notify_url", params.NotifyURL). - SetBodyMap("amount", func(bm gopay.BodyMap) { - bm.Set("total", params.TotalFee). - Set("currency", "CNY") - }). - SetBodyMap("scene_info", func(bm gopay.BodyMap) { - bm.Set("payer_client_ip", params.ClientIP). - SetBodyMap("h5_info", func(bm gopay.BodyMap) { - bm.Set("type", "Wap") - }) - }) - - wxRsp, err := s.client.V3TransactionH5(context.Background(), bm) - if err != nil { - return "", fmt.Errorf("error with client v3 transaction H5: %v", err) - } - if wxRsp.Code != wechat.Success { - return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error) - } - return wxRsp.Response.H5Url, nil -} - -type NotifyResponse struct { - Code string `json:"code"` - Message string `xml:"message"` -} - -// TradeVerify 交易验证 -func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo { - notifyReq, err := wechat.V3ParseNotify(request) - if err != nil { - return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)} - } - - // TODO: 这里验签程序有 Bug,一直报错:crypto/rsa: verification error,先暂时取消验签 - //err = notifyReq.VerifySignByPK(s.client.WxPublicKey()) - //if err != nil { - // return fmt.Errorf("error with client v3 verify sign: %v", err) - //} - - // 解密支付密文,验证订单信息 - result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key) - if err != nil { - return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)} - } - - return NotifyVo{ - Status: Success, - OutTradeNo: result.OutTradeNo, - TradeId: result.TransactionId, - Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100), - } -} diff --git a/api/service/payment/wxpay_service.go b/api/service/payment/wxpay_service.go new file mode 100644 index 00000000..fc4e100a --- /dev/null +++ b/api/service/payment/wxpay_service.go @@ -0,0 +1,216 @@ +package payment + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "fmt" + "geekai/core/types" + "geekai/utils" + "net/http" + "os" + "time" + + "github.com/go-pay/gopay" + "github.com/go-pay/gopay/wechat/v3" +) + +type WxPayService struct { + config *types.WxPayConfig + client *wechat.ClientV3 +} + +func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) { + config := sysConfig.Payment.WxPay + if !config.Enabled { + logger.Debug("Disabled WechatPay service") + } + + service := &WxPayService{config: &config} + if config.Enabled { + err := service.UpdateConfig(&config) + if err != nil { + logger.Errorf("微信支付服务初始化失败: %v", err) + } + } + + return service, nil +} + +func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error { + client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey) + if err != nil { + return fmt.Errorf("error with initialize WechatPay service: %v", err) + } + err = client.AutoVerifySign() + if err != nil { + return fmt.Errorf("error with autoVerifySign: %v", err) + } + s.client = client + if os.Getenv("GEEKAI_DEBUG") == "true" { + logger.Info("WechatPay Debug mode is enabled") + client.DebugSwitch = gopay.DebugOn + } + s.config = config + return nil +} + +func (s *WxPayService) Pay(params PayRequest) (string, error) { + expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) + // 初始化 BodyMap + bm := make(gopay.BodyMap) + bm.Set("appid", s.config.AppId). + Set("mchid", s.config.MchId). + Set("description", params.Subject). + Set("out_trade_no", params.OutTradeNo). + Set("time_expire", expire). + Set("notify_url", params.NotifyURL). + SetBodyMap("amount", func(bm gopay.BodyMap) { + bm.Set("total", utils.IntValue(params.TotalFee, 0)). + Set("currency", "CNY") + }) + if params.Device == "mobile" { + bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) { + bm.Set("payer_client_ip", params.ClientIP) + }).SetBodyMap("payer", func(bm gopay.BodyMap) { + bm.Set("openid", params.OpenID) + }) + wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm) + if err != nil { + return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err) + } + if wxRsp.Code != wechat.Success { + return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error) + } + return wxRsp.Response.PrepayId, nil + } else if params.Device == "pc" { + wxRsp, err := s.client.V3TransactionNative(context.Background(), bm) + if err != nil { + return "", fmt.Errorf("error with client v3 transaction Native: %v", err) + } + if wxRsp.Code != wechat.Success { + return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error) + } + return wxRsp.Response.CodeUrl, nil + + } + return "", nil +} + +func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) { + wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo) + if err != nil { + return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err) + } + + if wxRsp.Code != wechat.Success { + return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error) + } + + if wxRsp.Response.TradeState == "CLOSED" { + return OrderInfo{Status: Closed}, nil + } + + orderInfo := OrderInfo{ + OutTradeNo: wxRsp.Response.OutTradeNo, + TradeId: wxRsp.Response.TransactionId, + Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100), + PayTime: wxRsp.Response.SuccessTime, + } + if wxRsp.Response.TradeState == "SUCCESS" { + orderInfo.Status = Success + } else { + orderInfo.Status = Failure + } + return orderInfo, nil +} + +// TradeVerify 交易验证 +func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) { + notifyReq, err := wechat.V3ParseNotify(request) + if err != nil { + return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err) + } + + // 解密支付密文,验证订单信息 + result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key) + if err != nil { + return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err) + } + + return OrderInfo{ + Status: Success, + OutTradeNo: result.OutTradeNo, + TradeId: result.TransactionId, + Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100), + PayTime: result.SuccessTime, + }, nil +} + +// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) { +// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) +// // 初始化 BodyMap +// bm := make(gopay.BodyMap) +// bm.Set("appid", s.config.AppId). +// Set("mchid", s.config.MchId). +// Set("description", params.Subject). +// Set("out_trade_no", params.OutTradeNo). +// Set("time_expire", expire). +// Set("notify_url", params.NotifyURL). +// SetBodyMap("amount", func(bm gopay.BodyMap) { +// bm.Set("total", params.TotalFee). +// Set("currency", "CNY") +// }) + +// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm) +// if err != nil { +// return "", fmt.Errorf("error with client v3 transaction Native: %v", err) +// } +// if wxRsp.Code != wechat.Success { +// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error) +// } +// return wxRsp.Response.CodeUrl, nil +// } + +// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) { +// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339) +// // 初始化 BodyMap +// bm := make(gopay.BodyMap) +// bm.Set("appid", s.config.AppId). +// Set("mchid", s.config.MchId). +// Set("description", params.Subject). +// Set("out_trade_no", params.OutTradeNo). +// Set("time_expire", expire). +// Set("notify_url", params.NotifyURL). +// SetBodyMap("amount", func(bm gopay.BodyMap) { +// bm.Set("total", params.TotalFee). +// Set("currency", "CNY") +// }). +// SetBodyMap("scene_info", func(bm gopay.BodyMap) { +// bm.Set("payer_client_ip", params.ClientIP). +// SetBodyMap("h5_info", func(bm gopay.BodyMap) { +// bm.Set("type", "Wap") +// }) +// }) + +// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm) +// if err != nil { +// return "", fmt.Errorf("error with client v3 transaction H5: %v", err) +// } +// if wxRsp.Code != wechat.Success { +// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error) +// } +// return wxRsp.Response.H5Url, nil +// } + +// type NotifyResponse struct { +// Code string `json:"code"` +// Message string `xml:"message"` +// } + +var _ PayService = (*WxPayService)(nil) diff --git a/api/service/sms/aliyun.go b/api/service/sms/aliyun.go index d0ea1b97..858263dd 100644 --- a/api/service/sms/aliyun.go +++ b/api/service/sms/aliyun.go @@ -10,36 +10,57 @@ package sms import ( "fmt" "geekai/core/types" + "github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi" ) type AliYunSmsService struct { config *types.SmsConfigAli client *dysmsapi.Client + domain string + zoneId string } -func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) { - config := &appConfig.SMS.Ali - // 创建阿里云短信客户端 +func NewAliYunSmsService(sysConfig *types.SystemConfig) (*AliYunSmsService, error) { + config := &sysConfig.SMS.Ali + domain := "dysmsapi.aliyuncs.com" + zoneId := "cn-hangzhou" + + s := AliYunSmsService{ + config: config, + domain: domain, + zoneId: zoneId, + } + if sysConfig.SMS.Active == Ali { + err := s.UpdateConfig(config) + if err != nil { + logger.Errorf("阿里云短信初始化失败: %v", err) + } + } + return &s, nil +} + +func (s *AliYunSmsService) UpdateConfig(config *types.SmsConfigAli) error { client, err := dysmsapi.NewClientWithAccessKey( - "cn-hangzhou", + s.zoneId, config.AccessKey, config.AccessSecret) if err != nil { - return nil, fmt.Errorf("failed to create client: %v", err) + return fmt.Errorf("failed to create client: %v", err) } - - return &AliYunSmsService{ - config: config, - client: client, - }, nil + s.client = client + s.config = config + return nil } func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error { + if s.client == nil { + return fmt.Errorf("阿里云短信服务未初始化") + } // 创建短信请求并设置参数 request := dysmsapi.CreateSendSmsRequest() request.Scheme = "https" - request.Domain = s.config.Domain + request.Domain = s.domain request.PhoneNumbers = mobile request.SignName = s.config.Sign request.TemplateCode = s.config.CodeTempId diff --git a/api/service/sms/bao.go b/api/service/sms/bao.go index a00398de..19a1537c 100644 --- a/api/service/sms/bao.go +++ b/api/service/sms/bao.go @@ -20,19 +20,20 @@ import ( type BaoSmsService struct { config *types.SmsConfigBao + domain string } -func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService { - config := appConfig.SMS.Bao - if config.Domain == "" { // use default domain - config.Domain = "api.smsbao.com" - logger.Infof("Using default domain for SMS-BAO: %s", config.Domain) - } +func NewBaoSmsService(sysConfig *types.SystemConfig) *BaoSmsService { return &BaoSmsService{ - config: &config, + config: &sysConfig.SMS.Bao, + domain: "api.smsbao.com", } } +func (s *BaoSmsService) UpdateConfig(config *types.SmsConfigBao) { + s.config = config +} + var errMsg = map[string]string{ "0": "短信发送成功", "-1": "参数不全", @@ -56,7 +57,7 @@ func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error { params.Set("m", mobile) params.Set("c", content) - apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode()) + apiURL := fmt.Sprintf("https://%s/sms?%s", s.domain, params.Encode()) response, err := http.Get(apiURL) if err != nil { return err diff --git a/api/service/sms/service_manager.go b/api/service/sms/service_manager.go index 0a4fcac8..3e2ac652 100644 --- a/api/service/sms/service_manager.go +++ b/api/service/sms/service_manager.go @@ -10,37 +10,32 @@ package sms import ( "geekai/core/types" logger2 "geekai/logger" - "strings" ) -type ServiceManager struct { - handler Service +type SmsManager struct { + aliyun *AliYunSmsService + bao *BaoSmsService + active string } var logger = logger2.GetLogger() -func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) { - active := Ali - if config.SMS.Active != "" { - active = strings.ToUpper(config.SMS.Active) - } - var handler Service - switch active { - case Ali: - client, err := NewAliYunSmsService(config) - if err != nil { - return nil, err - } - handler = client - break - case Bao: - handler = NewSmsBaoSmsService(config) - break - } +func NewSmsManager(sysConfig *types.SystemConfig, aliyun *AliYunSmsService, bao *BaoSmsService) (*SmsManager, error) { - return &ServiceManager{handler: handler}, nil + return &SmsManager{ + active: sysConfig.SMS.Active, + aliyun: aliyun, + bao: bao, + }, nil } -func (m *ServiceManager) GetService() Service { - return m.handler +func (m *SmsManager) GetService() Service { + if m.active == Ali { + return m.aliyun + } + return m.bao +} + +func (m *SmsManager) SetActive(active string) { + m.active = active } diff --git a/api/service/wxlogin_service.go b/api/service/wxlogin_service.go new file mode 100644 index 00000000..ac6c704b --- /dev/null +++ b/api/service/wxlogin_service.go @@ -0,0 +1,109 @@ +package service + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "errors" + "fmt" + "geekai/core/types" + "geekai/utils" + "time" + + "github.com/go-redis/redis/v8" + "github.com/imroc/req/v3" +) + +type WxLoginService struct { + config types.WxLoginConfig + client *req.Client + redisClient *redis.Client +} + +const loginStateKeyPrefix = "wx_login_state/" + +type LoginStatus struct { + Status string `json:"status"` + OpenID string `json:"openid"` + Token string `json:"token"` +} + +const ( + LoginStatusPending = "pending" + LoginStatusSuccess = "success" + LoginStatusExpired = "expired" // 登录失效,需要重新登录 +) + +func NewWxLoginService(config types.WxLoginConfig, redisClient *redis.Client) *WxLoginService { + return &WxLoginService{ + config: config, + client: req.C().SetTimeout(10 * time.Second), + redisClient: redisClient, + } +} + +func (s *WxLoginService) UpdateConfig(config types.WxLoginConfig) { + s.config = config +} + +func (s *WxLoginService) GetLoginQrCodeUrl(state string) (string, error) { + if s.config.ApiKey == "" { + return "", errors.New("无效的 API Key") + } + + url := fmt.Sprintf("%s/api/auth/wechat/login", types.GeekAPIURL) + var res struct { + Code types.BizCode `json:"code"` + Message string `json:"message"` + Data struct { + Ticket string `json:"ticket"` + Url string `json:"url"` + } `json:"data"` + } + r, err := s.client.R(). + SetHeader("Authorization", s.config.ApiKey). + SetBody(map[string]string{ + "notify_url": s.config.NotifyURL, + "state": state, + }). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("请求 API 失败:%v", err) + } + + if res.Code != types.Success { + return "", fmt.Errorf("请求 API 失败:%s", res.Message) + } + + status := LoginStatus{ + Status: LoginStatusPending, + OpenID: "", + } + s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour) + + return res.Data.Url, nil +} + +func (s *WxLoginService) GetLoginStatus(state string) (*LoginStatus, error) { + result, err := s.redisClient.Get(context.Background(), loginStateKeyPrefix+state).Result() + if err != nil { + return nil, errors.New("登录失败") + } + + var status LoginStatus + err = utils.JsonDecode(result, &status) + if err != nil { + return nil, errors.New("登录失败") + } + + return &status, nil +} + +func (s *WxLoginService) SetLoginStatus(state string, status LoginStatus) { + s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour) +} diff --git a/api/store/model/order.go b/api/store/model/order.go index 9eda023a..3fc2a06b 100644 --- a/api/store/model/order.go +++ b/api/store/model/order.go @@ -7,21 +7,22 @@ import ( // Order 充值订单 type Order struct { - Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"` - UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"` - ProductId uint `gorm:"column:product_id;type:int;not null;comment:产品ID" json:"product_id"` - Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"` - OrderNo string `gorm:"column:order_no;type:varchar(30);uniqueIndex;not null;comment:订单ID" json:"order_no"` - TradeNo string `gorm:"column:trade_no;type:varchar(60);comment:支付平台交易流水号" json:"trade_no"` - Subject string `gorm:"column:subject;type:varchar(100);not null;comment:订单产品" json:"subject"` - Amount float64 `gorm:"column:amount;type:decimal(10,2);not null;default:0.00;comment:订单金额" json:"amount"` - Status types.OrderStatus `gorm:"column:status;type:tinyint(1);not null;default:0;comment:订单状态(0:待支付,1:已扫码,2:支付成功)" json:"status"` - Remark string `gorm:"column:remark;type:varchar(255);not null;comment:备注" json:"remark"` - PayTime int64 `gorm:"column:pay_time;type:int;comment:支付时间" json:"pay_time"` - PayWay string `gorm:"column:pay_way;type:varchar(20);not null;comment:支付方式" json:"pay_way"` - PayType string `gorm:"column:pay_type;type:varchar(30);not null;comment:支付类型" json:"pay_type"` - CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"` - UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"` + Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"` + UserId uint `gorm:"column:user_id;type:int;not null;comment:用户ID" json:"user_id"` + ProductId uint `gorm:"column:product_id;type:int;not null;comment:产品ID" json:"product_id"` + Username string `gorm:"column:username;type:varchar(30);not null;comment:用户名" json:"username"` + OrderNo string `gorm:"column:order_no;type:varchar(30);uniqueIndex;not null;comment:订单ID" json:"order_no"` + TradeNo string `gorm:"column:trade_no;type:varchar(60);comment:支付平台交易流水号" json:"trade_no"` + Subject string `gorm:"column:subject;type:varchar(100);not null;comment:订单产品" json:"subject"` + Amount float64 `gorm:"column:amount;type:decimal(10,2);not null;default:0.00;comment:订单金额" json:"amount"` + Status types.OrderStatus `gorm:"column:status;type:tinyint(1);not null;default:0;comment:订单状态(0:待支付,1:已扫码,2:支付成功)" json:"status"` + Remark string `gorm:"column:remark;type:varchar(255);not null;comment:备注" json:"remark"` + PayTime int64 `gorm:"column:pay_time;type:int;comment:支付时间" json:"pay_time"` + PayWay string `gorm:"column:pay_way;type:varchar(20);not null;comment:支付方式" json:"pay_way"` + Channel string `gorm:"column:channel;type:varchar(30);not null;comment:支付类型渠道:支付宝,微信,聚合支付"` // 支付类型渠道 + CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null" json:"updated_at"` + Checked bool `gorm:"column:checked;type:tinyint;not null;default:0;comment:是否已检查"` // 是否已检查 } func (m *Order) TableName() string { diff --git a/api/store/vo/order.go b/api/store/vo/order.go index 8ec4f2f9..22cdb267 100644 --- a/api/store/vo/order.go +++ b/api/store/vo/order.go @@ -6,18 +6,18 @@ import ( type Order struct { BaseVo - UserId uint `json:"user_id"` - ProductId uint `json:"product_id"` - Username string `json:"username"` - OrderNo string `json:"order_no"` - TradeNo string `json:"trade_no"` - Subject string `json:"subject"` - Amount float64 `json:"amount"` - Status types.OrderStatus `json:"status"` - PayTime int64 `json:"pay_time"` - PayWay string `json:"pay_way"` - PayType string `json:"pay_type"` - PayMethod string `json:"pay_method"` - PayName string `json:"pay_name"` - Remark types.OrderRemark `json:"remark"` + UserId uint `json:"user_id"` + ProductId uint `json:"product_id"` + Username string `json:"username"` + OrderNo string `json:"order_no"` + TradeNo string `json:"trade_no"` + Subject string `json:"subject"` + Amount float64 `json:"amount"` + Status types.OrderStatus `json:"status"` + PayTime int64 `json:"pay_time"` + PayWay string `json:"pay_way"` + Channel string `json:"channel"` + ChannelName string `json:"channel_name"` + PayName string `json:"pay_name"` + Remark types.OrderRemark `json:"remark"` } diff --git a/api/utils/net.go b/api/utils/net.go index 5e8a0985..9fd2c09a 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -15,6 +15,7 @@ import ( "io" "net/http" "net/url" + "path/filepath" ) var logger = logger2.GetLogger() @@ -92,3 +93,11 @@ func GetBaseURL(strURL string) string { } return fmt.Sprintf("%s://%s", u.Scheme, u.Host) } + +func GetImgExt(filename string) string { + ext := filepath.Ext(filename) + if ext == "" { + return ".png" + } + return ext +}