merge v4.2.6

整合 v4.2.6 的后端中间件与服务层重构、前端样式体系迁移和管理端/移动端功能更新,统一清理历史冲突并完成版本升级。

Made-with: Cursor
This commit is contained in:
RockYang
2026-04-08 15:08:34 +08:00
390 changed files with 35519 additions and 25073 deletions

View File

@@ -8,118 +8,50 @@ package core
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bytes"
"context"
"fmt"
"geekai/core/middleware"
"geekai/core/types"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"image"
"image/jpeg"
"io"
"net/http"
"os"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"github.com/nfnt/resize"
"github.com/shirou/gopsutil/host"
"golang.org/x/image/webp"
"gorm.io/gorm"
)
// AuthConfig 定义授权配置
type AuthConfig struct {
ExactPaths map[string]bool // 精确匹配的路径
PrefixPaths map[string]bool // 前缀匹配的路径
}
var authConfig = &AuthConfig{
ExactPaths: map[string]bool{
"/api/user/login": false,
"/api/user/logout": false,
"/api/user/resetPass": false,
"/api/user/register": false,
"/api/admin/login": false,
"/api/admin/logout": false,
"/api/admin/login/captcha": false,
"/api/app/list": false,
"/api/app/type/list": false,
"/api/app/list/user": false,
"/api/model/list": false,
"/api/mj/imgWall": false,
"/api/mj/notify": false,
"/api/invite/hits": false,
"/api/sd/imgWall": false,
"/api/dall/imgWall": false,
"/api/product/list": false,
"/api/menu/list": false,
"/api/markMap/client": false,
"/api/payment/doPay": false,
"/api/payment/payWays": false,
"/api/download": false,
"/api/dall/models": false,
},
PrefixPaths: map[string]bool{
"/api/test/": false,
"/api/payment/notify/": false,
"/api/user/clogin": false,
"/api/config/": false,
"/api/function/": false,
"/api/sms/": false,
"/api/captcha/": false,
"/static/": false,
},
}
type AppServer struct {
Config *types.AppConfig
Engine *gin.Engine
SysConfig *types.SystemConfig // system config cache
Redis *redis.Client
}
func NewServer(appConfig *types.AppConfig) *AppServer {
func NewServer(appConfig *types.AppConfig, redis *redis.Client, sysConfig *types.SystemConfig) *AppServer {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
return &AppServer{
Config: appConfig,
Engine: gin.Default(),
Config: appConfig,
Redis: redis,
Engine: gin.Default(),
SysConfig: sysConfig,
}
}
func (s *AppServer) Init(debug bool, client *redis.Client) {
// 允许跨域请求 API
s.Engine.Use(corsMiddleware())
s.Engine.Use(staticResourceMiddleware())
s.Engine.Use(authorizeMiddleware(s, client))
s.Engine.Use(parameterHandlerMiddleware())
func (s *AppServer) Init(client *redis.Client) {
s.Engine.Use(middleware.ParameterHandlerMiddleware())
s.Engine.Use(errorHandler)
// 添加静态资源访问
s.Engine.Static("/static", s.Config.StaticDir)
s.Engine.Use(middleware.StaticMiddleware())
}
func (s *AppServer) Run(db *gorm.DB) error {
// 重命名 config 表字段
if db.Migrator().HasColumn(&model.Config{}, "config_json") {
db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
}
if db.Migrator().HasColumn(&model.Config{}, "marker") {
db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
}
if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
}
if db.Migrator().HasIndex(&model.Config{}, "marker") {
db.Migrator().DropIndex(&model.Config{}, "marker")
}
// load system configs
var sysConfig model.Config
err := db.Where("name", "system").First(&sysConfig).Error
@@ -131,94 +63,22 @@ func (s *AppServer) Run(db *gorm.DB) error {
return fmt.Errorf("failed to decode system config: %v", err)
}
// 迁移数据表
logger.Info("Migrating database tables...")
db.AutoMigrate(
&model.ChatItem{},
&model.ChatMessage{},
&model.ChatRole{},
&model.ChatModel{},
&model.InviteCode{},
&model.InviteLog{},
&model.Menu{},
&model.Order{},
&model.Product{},
&model.User{},
&model.Function{},
&model.File{},
&model.Redeem{},
&model.Config{},
&model.ApiKey{},
&model.AdminUser{},
&model.AppType{},
&model.SdJob{},
&model.SunoJob{},
&model.PowerLog{},
&model.VideoJob{},
&model.MidJourneyJob{},
&model.UserLoginLog{},
&model.DallJob{},
&model.JimengJob{},
)
// 手动删除字段
if db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
db.Migrator().DropColumn(&model.Order{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
}
if db.Migrator().HasColumn(&model.User{}, "chat_config") {
db.Migrator().DropColumn(&model.User{}, "chat_config")
}
if db.Migrator().HasColumn(&model.ChatModel{}, "category") {
db.Migrator().DropColumn(&model.ChatModel{}, "category")
}
if db.Migrator().HasColumn(&model.ChatModel{}, "description") {
db.Migrator().DropColumn(&model.ChatModel{}, "description")
}
logger.Info("Database tables migrated successfully")
// 统计安装信息
go func() {
info, err := host.Info()
if err == nil {
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
apiURL := fmt.Sprintf("%s/api/installs/push", types.GeekAPIURL)
timestamp := time.Now().Unix()
product := "geekai-plus"
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
sign := utils.Sha256(signStr)
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
if err != nil {
logger.Errorf("register install info failed: %v", err)
} else {
if err == nil {
logger.Debugf("register install info success: %v", resp.String())
}
}
}()
logger.Infof("http://%s", s.Config.Listen)
// 统计安装信息
go func() {
info, err := host.Info()
if err == nil {
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
timestamp := time.Now().Unix()
product := "geekai-plus"
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
sign := utils.Sha256(signStr)
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
if err != nil {
logger.Errorf("register install info failed: %v", err)
} else {
logger.Debugf("register install info success: %v", resp.String())
}
}
}()
return s.Engine.Run(s.Config.Listen)
}
@@ -235,283 +95,3 @@ func errorHandler(c *gin.Context) {
//加载完 defer recover继续后续接口调用
c.Next()
}
// 跨域中间件设置
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
origin := c.Request.Header.Get("Origin")
// 设置允许的请求源
if origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
} else {
c.Header("Access-Control-Allow-Origin", "*")
}
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
// 允许浏览器(客户端)可以解析的头部 (重要)
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
//设置缓存时间
c.Header("Access-Control-Max-Age", "172800")
//允许客户端传递校验信息比如 cookie (重要)
c.Header("Access-Control-Allow-Credentials", "true")
if method == http.MethodOptions {
c.JSON(http.StatusOK, "ok!")
}
defer func() {
if err := recover(); err != nil {
logger.Info("Panic info is: %v", err)
}
}()
c.Next()
}
}
// 用户授权验证
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
if !needLogin(c) {
c.Next()
return
}
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
var tokenString string
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
if isAdminApi { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader)
} else if clientProtocols != "" { // Websocket 连接
// 解析子协议内容
protocols := strings.Split(clientProtocols, ",")
if protocols[0] == "realtime" {
tokenString = strings.TrimSpace(protocols[1][25:])
} else if protocols[0] == "token" {
tokenString = strings.TrimSpace(protocols[1])
}
} else {
tokenString = c.GetHeader(types.UserAuthHeader)
}
if tokenString == "" {
resp.NotAuth(c, "You should put Authorization in request headers")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
if isAdminApi {
return []byte(s.Config.AdminSession.SecretKey), nil
} else {
return []byte(s.Config.Session.SecretKey), nil
}
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "Token is invalid")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "Token is expired")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if isAdminApi {
key = fmt.Sprintf("admin/%v", claims["user_id"])
}
if _, err := client.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "Token is not found in redis")
c.Abort()
return
}
c.Set(types.LoginUserID, claims["user_id"])
c.Next()
}
}
func needLogin(c *gin.Context) bool {
path := c.Request.URL.Path
// 如果不是 API 路径,不需要登录
if !strings.HasPrefix(path, "/api") {
return false
}
// 检查精确匹配的路径
if skip, exists := authConfig.ExactPaths[path]; exists {
return skip
}
// 检查前缀匹配的路径
for prefix, skip := range authConfig.PrefixPaths {
if strings.HasPrefix(path, prefix) {
return skip
}
}
return true
}
// 跳过授权
func (s *AppServer) SkipAuth(url string, prefix bool) {
if prefix {
authConfig.PrefixPaths[url] = false
} else {
authConfig.ExactPaths[url] = false
}
}
// 统一参数处理
func parameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// GET 参数处理
params := c.Request.URL.Query()
for key, values := range params {
for i, value := range values {
params[key][i] = strings.TrimSpace(value)
}
}
// update get parameters
c.Request.URL.RawQuery = params.Encode()
// skip file upload requests
contentType := c.Request.Header.Get("Content-Type")
if strings.Contains(contentType, "multipart/form-data") {
c.Next()
return
}
if strings.Contains(contentType, "application/json") {
// process POST JSON request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.Next()
return
}
// 还原请求体
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 将请求体解析为 JSON
var jsonData map[string]interface{}
if err := c.ShouldBindJSON(&jsonData); err != nil {
c.Next()
return
}
// 对 JSON 数据中的字符串值去除两端空格
trimJSONStrings(jsonData)
// 更新请求体
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
}
c.Next()
}
}
// 递归对 JSON 数据中的字符串值去除两端空格
func trimJSONStrings(data interface{}) {
switch v := data.(type) {
case map[string]interface{}:
for key, value := range v {
switch valueType := value.(type) {
case string:
v[key] = strings.TrimSpace(valueType)
case map[string]interface{}, []interface{}:
trimJSONStrings(value)
}
}
case []interface{}:
for i, value := range v {
switch valueType := value.(type) {
case string:
v[i] = strings.TrimSpace(valueType)
case map[string]interface{}, []interface{}:
trimJSONStrings(value)
}
}
}
}
// 静态资源中间件
func staticResourceMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
url := c.Request.URL.String()
// 拦截生成缩略图请求
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
r := strings.SplitAfter(url, "imageView2")
size := strings.Split(r[1], "/")
if len(size) != 8 {
c.String(http.StatusNotFound, "invalid thumb args")
return
}
with := utils.IntValue(size[3], 0)
height := utils.IntValue(size[5], 0)
quality := utils.IntValue(size[7], 75)
// 打开图片文件
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
file, err := os.Open(filePath)
if err != nil {
c.String(http.StatusNotFound, "Image not found")
return
}
defer file.Close()
// 解码图片
img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
}
var newImg image.Image
if height == 0 || with == 0 {
// 固定宽度,高度自适应
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
} else {
// 生成缩略图
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
}
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
logger.Error(err)
c.String(http.StatusInternalServerError, err.Error())
return
}
// 设置图片缓存有效期为一年 (365天)
c.Header("Cache-Control", "max-age=31536000, public")
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求
}
c.Next()
}
}

View File

@@ -11,10 +11,12 @@ import (
"bytes"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/store/model"
"geekai/utils"
"os"
"github.com/BurntSushi/toml"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
@@ -30,7 +32,6 @@ func NewDefaultConfig() *types.AppConfig {
SecretKey: utils.RandString(64),
MaxAge: 86400,
},
ApiConfig: types.ApiConfig{},
OSS: types.OSSConfig{
Active: "local",
Local: types.LocalStorageConfig{
@@ -38,7 +39,6 @@ func NewDefaultConfig() *types.AppConfig {
BasePath: "./static/upload",
},
},
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
}
}
@@ -74,3 +74,108 @@ func SaveConfig(config *types.AppConfig) error {
return os.WriteFile(config.Path, buf.Bytes(), 0644)
}
func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
// 加载系统配置
var sysConfig model.Config
var baseConfig types.BaseConfig
db.Where("name", "system").First(&sysConfig)
err := utils.JsonDecode(sysConfig.Value, &baseConfig)
if err != nil {
logger.Error("load system config error: ", err)
}
// 加载许可证配置
var license types.License
sysConfig.Id = 0
db.Where("name", types.ConfigKeyLicense).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &license)
if err != nil {
logger.Error("load license config error: ", err)
}
// 加载验证码配置
var captchaConfig types.CaptchaConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyCaptcha).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &captchaConfig)
if err != nil {
logger.Error("load geek service config error: ", err)
}
// 加载微信登录配置
var wxLoginConfig types.WxLoginConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyWxLogin).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &wxLoginConfig)
if err != nil {
logger.Error("load wx login config error: ", err)
}
// 加载短信配置
var smsConfig types.SMSConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeySms).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &smsConfig)
if err != nil {
logger.Error("load sms config error: ", err)
}
// 加载 OSS 配置
var ossConfig types.OSSConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyOss).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &ossConfig)
if err != nil {
logger.Error("load oss config error: ", err)
}
// 加载 SMTP 配置
var smtpConfig types.SmtpConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeySmtp).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &smtpConfig)
if err != nil {
logger.Error("load smtp config error: ", err)
}
// 加载支付配置
var paymentConfig types.PaymentConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyPayment).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &paymentConfig)
if err != nil {
logger.Error("load payment config error: ", err)
}
// 加载文本审查配置
var moderationConfig types.ModerationConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyModeration).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &moderationConfig)
if err != nil {
logger.Error("load moderation config error: ", err)
}
// 加载即梦AI配置
var jimengConfig types.JimengConfig
sysConfig.Id = 0
db.Where("name", types.ConfigKeyJimeng).First(&sysConfig)
err = utils.JsonDecode(sysConfig.Value, &jimengConfig)
if err != nil {
logger.Error("load jimeng config error: ", err)
}
return &types.SystemConfig{
Base: baseConfig,
License: license,
SMS: smsConfig,
OSS: ossConfig,
SMTP: smtpConfig,
Payment: paymentConfig,
Captcha: captchaConfig,
WxLogin: wxLoginConfig,
Moderation: moderationConfig,
Jimeng: jimengConfig,
}
}

109
api/core/middleware/auth.go Normal file
View File

@@ -0,0 +1,109 @@
package middleware
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt"
)
// 前端用户授权验证
func UserAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader(types.UserAuthHeader)
if tokenString == "" {
resp.NotAuth(c, "无效的授权令牌")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "令牌无效")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "令牌过期")
c.Abort()
return
}
key := fmt.Sprintf("users/%v", claims["user_id"])
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "当前用户已退出登录")
c.Abort()
return
}
c.Set(types.LoginUserID, claims["user_id"])
}
}
// 管理后台用户授权验证
func AdminAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader(types.AdminAuthHeader)
if tokenString == "" {
resp.NotAuth(c, "无效的授权令牌")
c.Abort()
return
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
c.Abort()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
resp.NotAuth(c, "令牌无效")
c.Abort()
return
}
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
if expr > 0 && int64(expr) < time.Now().Unix() {
resp.NotAuth(c, "令牌过期")
c.Abort()
return
}
key := fmt.Sprintf("admin/%v", claims["user_id"])
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c, "当前用户已退出登录")
c.Abort()
return
}
c.Set(types.AdminUserID, claims["user_id"])
}
}

View File

@@ -0,0 +1,80 @@
package middleware
import (
"bytes"
"geekai/utils"
"io"
"strings"
"github.com/gin-gonic/gin"
)
// 统一参数处理
func ParameterHandlerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// GET 参数处理
params := c.Request.URL.Query()
for key, values := range params {
for i, value := range values {
params[key][i] = strings.TrimSpace(value)
}
}
// update get parameters
c.Request.URL.RawQuery = params.Encode()
// skip file upload requests
contentType := c.Request.Header.Get("Content-Type")
if strings.Contains(contentType, "multipart/form-data") {
c.Next()
return
}
if strings.Contains(contentType, "application/json") {
// process POST JSON request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.Next()
return
}
// 还原请求体
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 将请求体解析为 JSON
var jsonData map[string]any
if err := c.ShouldBindJSON(&jsonData); err != nil {
c.Next()
return
}
// 对 JSON 数据中的字符串值去除两端空格
trimJSONStrings(jsonData)
// 更新请求体
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
}
c.Next()
}
}
// 递归对 JSON 数据中的字符串值去除两端空格
func trimJSONStrings(data any) {
switch v := data.(type) {
case map[string]any:
for key, value := range v {
switch valueType := value.(type) {
case string:
v[key] = strings.TrimSpace(valueType)
case map[string]any, []any:
trimJSONStrings(value)
}
}
case []any:
for i, value := range v {
switch valueType := value.(type) {
case string:
v[i] = strings.TrimSpace(valueType)
case map[string]any, []any:
trimJSONStrings(value)
}
}
}
}

View File

@@ -0,0 +1,43 @@
package middleware
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
// RateLimitEvery 使用 Redis 做固定间隔限流:在 interval 内仅允许一次请求
// Key 优先使用登录用户ID若没有则退化为 route + IP
func RateLimitEvery(redisClient *redis.Client, interval time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
keyID := ""
if userID, ok := c.Get(types.LoginUserID); ok {
keyID = fmt.Sprintf("user:%s", utils.InterfaceToString(userID))
} else {
keyID = fmt.Sprintf("ip:%s", c.ClientIP())
}
fullPath := c.FullPath()
if fullPath == "" {
fullPath = c.Request.URL.Path
}
key := fmt.Sprintf("rl:%s:%s", fullPath, keyID)
okSet, err := redisClient.SetNX(context.Background(), key, 1, interval).Result()
if err != nil {
// Redis 异常时放行,避免误伤可用性
return
}
if !okSet {
c.JSON(http.StatusTooManyRequests, types.BizVo{Code: types.Failed, Message: "请求过于频繁,请稍后重试"})
c.Abort()
return
}
}
}

View File

@@ -0,0 +1,78 @@
package middleware
import (
"bytes"
"geekai/utils"
"image"
"image/jpeg"
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"github.com/nfnt/resize"
"golang.org/x/image/webp"
)
// 静态资源中间件
func StaticMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
url := c.Request.URL.String()
// 拦截生成缩略图请求
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
r := strings.SplitAfter(url, "imageView2")
size := strings.Split(r[1], "/")
if len(size) != 8 {
c.String(http.StatusNotFound, "invalid thumb args")
return
}
with := utils.IntValue(size[3], 0)
height := utils.IntValue(size[5], 0)
quality := utils.IntValue(size[7], 75)
// 打开图片文件
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
file, err := os.Open(filePath)
if err != nil {
c.String(http.StatusNotFound, "Image not found")
return
}
defer file.Close()
// 解码图片
img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
}
var newImg image.Image
if height == 0 || with == 0 {
// 固定宽度,高度自适应
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
} else {
// 生成缩略图
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
}
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
// 设置图片缓存有效期为一年 (365天)
c.Header("Cache-Control", "max-age=31536000, public")
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求
}
c.Next()
}
}

View File

@@ -17,88 +17,17 @@ type AppConfig struct {
Session Session
AdminSession Session
ProxyURL string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
SmtpConfig SmtpConfig // 邮件发送配置
XXLConfig XXLConfig
AlipayConfig AlipayConfig // 支付宝支付渠道配置
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
GeekPayConfig GeekPayConfig // GEEK 支付配置
WechatPayConfig WechatPayConfig // 微信支付渠道配置
TikaHost string // TiKa 服务器地址
}
type SmtpConfig struct {
UseTls bool // 是否使用 TLS 发送
Host string
Port int
AppName string // 应用名称
From string // 发件人邮箱地址
Password string // 发件人邮箱密码
}
type ApiConfig struct {
ApiURL string
AppId string
Token string
JimengConfig JimengConfig // 即梦AI配置
}
type AlipayConfig struct {
Enabled bool // 是否启用该支付通道
SandBox bool // 是否沙盒环境
AppId string // 应用 ID
UserId string // 支付宝用户 ID
PrivateKey string // 用户私钥文件路径
PublicKey string // 用户公钥文件路径
AlipayPublicKey string // 支付宝公钥文件路径
RootCert string // Root 秘钥路径
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
}
type WechatPayConfig struct {
Enabled bool // 是否启用该支付通道
AppId string // 公众号的APPID,如wxd678efh567hg6787
MchId string // 直连商户的商户号,由微信支付生成并下发
SerialNo string // 商户证书的证书序列号
PrivateKey string // 用户私钥文件路径
ApiV3Key string // API V3 秘钥
NotifyURL string // 异步通知地址
}
type HuPiPayConfig struct { //虎皮椒第四方支付配置
Enabled bool // 是否启用该支付通道
AppId string // App ID
AppSecret string // app 密钥
ApiURL string // 支付网关
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
}
// GeekPayConfig GEEK支付配置
type GeekPayConfig struct {
Enabled bool
AppId string // 商户 ID
PrivateKey string // 私钥
ApiURL string // API 网关
NotifyURL string // 异步通知地址
ReturnURL string // 同步通知地址
Methods []string // 支付方式
}
type XXLConfig struct { // XXL 任务调度配置
Enabled bool
ServerAddr string
ExecutorIp string
ExecutorPort string
AccessToken string
RegistryKey string
MysqlDns string // mysql 连接地址
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config
SmtpConfig SmtpConfig // 邮件发送配置
AlipayConfig AlipayConfig // 支付宝支付渠道配置
GeekPayConfig EpayConfig // GEEK 支付配置
WechatPayConfig WxPayConfig // 微信支付渠道配置
TikaHost string // TiKa 服务器地址
}
type RedisConfig struct {
@@ -128,32 +57,28 @@ func (c RedisConfig) Url() string {
return fmt.Sprintf("%s:%d", c.Host, c.Port)
}
type SystemConfig struct {
Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"` // 圆形 Logo
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
type BaseConfig struct {
Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"` // 圆形 Logo
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式支持手机mobile邮箱注册email账号密码注册
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间,单位:分钟
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
@@ -163,15 +88,44 @@ type SystemConfig struct {
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息
DefaultNickname string `json:"default_nickname"` // 默认昵称
ICP string `json:"icp"` // ICP 备案号
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息
ICP string `json:"icp"` // ICP 备案号
GaBeian string `json:"ga_beian"` // 公安备案号
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位MB
}
type SystemConfig struct {
Base BaseConfig
Payment PaymentConfig
OSS OSSConfig
SMS SMSConfig
SMTP SmtpConfig
Captcha CaptchaConfig
WxLogin WxLoginConfig
Jimeng JimengConfig
License License
Moderation ModerationConfig
}
// 配置键名常量
const (
ConfigKeySystem = "system"
ConfigKeyNotice = "notice"
ConfigKeyAgreement = "agreement"
ConfigKeyPrivacy = "privacy"
ConfigKeyMarkMap = "mark_map"
ConfigKeyCaptcha = "captcha"
ConfigKeyWxLogin = "wx_login"
ConfigKeyLicense = "license"
ConfigKeySms = "sms"
ConfigKeySmtp = "smtp"
ConfigKeyOss = "oss"
ConfigKeyPayment = "payment"
ConfigKeyModeration = "moderation"
ConfigKeyAI3D = "ai3d"
ConfigKeyJimeng = "jimeng"
)

33
api/core/types/geekai.go Normal file
View File

@@ -0,0 +1,33 @@
package types
import "os"
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// GeekAI 增值服务
var GeekAPIURL = "https://sapi.geekai.me"
func init() {
if os.Getenv("GEEK_API_URL") != "" {
GeekAPIURL = os.Getenv("GEEK_API_URL")
}
}
// CaptchaConfig 行为验证码配置
type CaptchaConfig struct {
ApiKey string `json:"api_key"`
Type string `json:"type"` // 验证码类型, 可选值: "dot" 或 "slide"
Enabled bool `json:"enabled"`
}
// WxLoginConfig 微信登录配置
type WxLoginConfig struct {
ApiKey string `json:"api_key"`
NotifyURL string `json:"notify_url"` // 登录成功回调 URL
Enabled bool `json:"enabled"` // 是否启用微信登录
}

View File

@@ -0,0 +1,73 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// 文本审查
type ModerationConfig struct {
Enable bool `json:"enable"` // 是否启用文本审查
Active string `json:"active"`
EnableGuide bool `json:"enable_guide"` // 是否启用模型引导提示词
GuidePrompt string `json:"guide_prompt"` // 模型引导提示词
Gitee ModerationGiteeConfig `json:"gitee"`
Baidu ModerationBaiduConfig `json:"baidu"`
Tencent ModerationTencentConfig `json:"tencent"`
}
const (
ModerationGitee = "gitee"
ModerationBaidu = "baidu"
ModerationTencent = "tencent"
)
// GiteeAI 文本审查配置
type ModerationGiteeConfig struct {
ApiKey string `json:"api_key"`
Model string `json:"model"` // 文本审核模型
}
// 百度文本审查配置
type ModerationBaiduConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
}
// 腾讯云文本审查配置
type ModerationTencentConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
}
type ModerationResult struct {
Flagged bool `json:"flagged"`
Categories map[string]bool `json:"categories"`
CategoryScores map[string]float64 `json:"category_scores"`
}
var ModerationCategories = map[string]string{
"politic": "内容涉及人物、事件或敏感的政治观点",
"porn": "明确的色情内容",
"insult": "具有侮辱、攻击性语言、人身攻击或冒犯性表达",
"violence": "包含暴力、血腥、攻击行为或煽动暴力的言论",
"illegal": "涉及违法活动的内容,如诈骗、赌博等",
"terror": "宣扬恐怖主义、极端暴力或煽动恐怖行为的内容",
"ad": "垃圾广告或未经许可的推广内容",
"spam": "无意义重复内容或诱导性信息",
"abuse": "人身攻击、恶意辱骂或侮辱性言论",
"polity": "涉及国家政治、领导人或政策的违规讨论内容",
}
// 敏感词来源
const (
ModerationSourceChat = "chat"
ModerationSourceMJ = "mj"
ModerationSourceDalle = "dalle"
ModerationSourceSD = "sd"
ModerationSourceSuno = "suno"
ModerationSourceVideo = "video"
ModerationSourceJiMeng = "jimeng"
)

View File

@@ -11,29 +11,25 @@ type OrderStatus int
const (
OrderNotPaid = OrderStatus(0)
OrderScanned = OrderStatus(1) // 已扫码
OrderPaidSuccess = OrderStatus(2)
OrderPaidSuccess = OrderStatus(2) // 已支付
OrderPaidFailed = OrderStatus(3) // 已关闭
)
type OrderRemark struct {
Days int `json:"days"` // 有效期
Power int `json:"power"` // 增加算力点数
Name string `json:"name"` // 产品名称
Price float64 `json:"price"`
Discount float64 `json:"discount"`
Days int `json:"days"` // 有效期
Power int `json:"power"` // 增加算力点数
Name string `json:"name"` // 产品名称
Price float64 `json:"price"`
}
var PayMethods = map[string]string{
// PayChannel 支付渠道
var PayChannel = map[string]string{
"alipay": "支付宝商号",
"wechat": "微信商号",
"hupi": "虎皮椒",
"geek": "易支付",
"wxpay": "微信商号",
"epay": "易支付",
}
var PayNames = map[string]string{
var PayWays = map[string]string{
"alipay": "支付宝",
"wxpay": "微信支付",
"qqpay": "QQ钱包",
"jdpay": "京东支付",
"douyin": "抖音支付",
"paypal": "PayPal支付",
}

View File

@@ -8,41 +8,39 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type OSSConfig struct {
Active string
Local LocalStorageConfig
Minio MiniOssConfig
QiNiu QiNiuOssConfig
AliYun AliYunOssConfig
Active string `json:"active"`
Local LocalStorageConfig `json:"local"`
Minio MiniOssConfig `json:"minio"`
QiNiu QiNiuOssConfig `json:"qiniu"`
AliYun AliYunOssConfig `json:"aliyun"`
}
type MiniOssConfig struct {
Endpoint string
AccessKey string
AccessSecret string
Bucket string
SubDir string
UseSSL bool
Domain string
Endpoint string `json:"endpoint"`
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Bucket string `json:"bucket"`
UseSSL bool `json:"use_ssl"`
Domain string `json:"domain"`
}
type QiNiuOssConfig struct {
Zone string
AccessKey string
AccessSecret string
Bucket string
SubDir string
Domain string
Zone string `json:"zone"`
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Bucket string `json:"bucket"`
Domain string `json:"domain"`
}
type AliYunOssConfig struct {
Endpoint string
AccessKey string
AccessSecret string
Bucket string
SubDir string
Domain string
Endpoint string `json:"endpoint"`
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Bucket string `json:"bucket"`
Domain string `json:"domain"`
}
type LocalStorageConfig struct {
BasePath string
BaseURL string
BasePath string `json:"base_path"`
BaseURL string `json:"base_url"`
}

60
api/core/types/payment.go Normal file
View File

@@ -0,0 +1,60 @@
package types
type PaymentConfig struct {
Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
Epay EpayConfig `json:"epay"` // 易支付配置
WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置
}
// AlipayConfig 支付宝支付配置
type AlipayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
SandBox bool `json:"sandbox"` // 是否沙盒环境
AppId string `json:"app_id"` // 应用 ID
PrivateKey string `json:"private_key"` // 应用私钥
AlipayPublicKey string `json:"alipay_public_key"` // 支付宝公钥
Domain string `json:"domain"` // 支付回调域名
}
func (c *AlipayConfig) Equal(other *AlipayConfig) bool {
return c.AppId == other.AppId &&
c.PrivateKey == other.PrivateKey &&
c.AlipayPublicKey == other.AlipayPublicKey &&
c.Domain == other.Domain
}
// WxPayConfig 微信支付配置
type WxPayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
AppId string `json:"app_id"` // 公众号的APPID,如wxd678efh567hg6787
MchId string `json:"mch_id"` // 直连商户的商户号,由微信支付生成并下发
SerialNo string `json:"serial_no"` // 商户证书的证书序列号
PrivateKey string `json:"private_key"` // 商户证书私钥
ApiV3Key string `json:"api_v3_key"` // API V3 秘钥
Domain string `json:"domain"` // 支付回调域名
}
func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
return c.AppId == other.AppId &&
c.MchId == other.MchId &&
c.SerialNo == other.SerialNo &&
c.PrivateKey == other.PrivateKey &&
c.ApiV3Key == other.ApiV3Key &&
c.Domain == other.Domain
}
// EpayConfig 易支付配置
type EpayConfig struct {
Enabled bool `json:"enabled"` // 是否启用该支付通道
AppId string `json:"app_id"` // 商户 ID
PrivateKey string `json:"private_key"` // 私钥
ApiURL string `json:"api_url"` // z支付 API 网关
Domain string `json:"domain"` // 支付回调域名
}
func (c *EpayConfig) Equal(other *EpayConfig) bool {
return c.AppId == other.AppId &&
c.PrivateKey == other.PrivateKey &&
c.ApiURL == other.ApiURL &&
c.Domain == other.Domain
}

View File

@@ -8,6 +8,7 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const LoginUserID = "LOGIN_USER_ID"
const AdminUserID = "ADMIN_USER_ID"
const LoginUserCache = "LOGIN_USER_CACHE"
const UserAuthHeader = "Authorization"

View File

@@ -8,26 +8,23 @@ package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SMSConfig struct {
Active string
Ali SmsConfigAli
Bao SmsConfigBao
Active string `json:"active"`
Ali SmsConfigAli `json:"aliyun"`
Bao SmsConfigBao `json:"bao"`
}
// SmsConfigAli 阿里云短信平台配置
type SmsConfigAli struct {
AccessKey string
AccessSecret string
Product string
Domain string
Sign string // 短信签名
CodeTempId string // 验证码短信模板 ID
AccessKey string `json:"access_key"`
AccessSecret string `json:"access_secret"`
Sign string `json:"sign"` // 短信签名
CodeTempId string `json:"code_temp_id"` // 验证码短信模板 ID
}
// SmsConfigBao 短信宝平台配置
type SmsConfigBao struct {
Username string //短信宝平台注册的用户名
Password string //短信宝平台注册的密码
Domain string //域
Sign string // 短信签名
CodeTemplate string // 验证码短信模板 匹配
Username string `json:"username"` //短信宝平台注册的用户名
Password string `json:"password"` //短信宝平台注册的密码
Sign string `json:"sign"` // 短信签
CodeTemplate string `json:"code_template"` // 验证码短信模板 匹配
}

26
api/core/types/smtp.go Normal file
View File

@@ -0,0 +1,26 @@
package types
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
type SmtpConfig struct {
UseTls bool `json:"use_tls"` // 是否使用 TLS 发送
Host string `json:"host"` // 邮件服务器地址
Port int `json:"port"` // 邮件服务器端口
AppName string `json:"app_name"` // 应用名称
From string `json:"from"` // 发件人邮箱地址
Password string `json:"password"` // 发件人邮箱密码
}
func (s *SmtpConfig) Equal(other *SmtpConfig) bool {
return s.UseTls == other.UseTls &&
s.Host == other.Host &&
s.Port == other.Port &&
s.AppName == other.AppName &&
s.From == other.From &&
s.Password == other.Password
}

View File

@@ -70,17 +70,18 @@ type SdTaskParams struct {
// DallTask DALL-E task
type DallTask struct {
ModelId uint `json:"model_id"`
ModelName string `json:"model_name"`
Id uint `json:"id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
ModelId uint `json:"model_id"`
ModelName string `json:"model_name"`
Image []string `json:"image,omitempty"`
Id uint `json:"id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
}
type SunoTask struct {

View File

@@ -4,7 +4,7 @@ build_name: runner-build
build_log: runner-build-errors.log
valid_ext: .go, .tpl, .tmpl, .html
no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue
ignored: assets, tmp, web, .git, .idea, test, data
ignored: assets, tmp, web, .git, .idea, test, data, static
build_delay: 600
colors: 1
log_color_main: cyan

View File

@@ -24,11 +24,9 @@ require (
gorm.io/driver/mysql v1.4.7
)
require github.com/xxl-job/xxl-job-executor-go v1.2.0
require (
github.com/go-pay/gopay v1.5.101
github.com/go-rod/rod v0.116.2
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/go-tika v0.3.1
github.com/microcosm-cc/bluemonday v1.0.26
github.com/sashabaranov/go-openai v1.38.1
@@ -50,11 +48,6 @@ require (
github.com/gorilla/css v1.0.0 // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/ysmood/fetchup v0.3.0 // indirect
github.com/ysmood/goob v0.4.0 // indirect
github.com/ysmood/got v0.40.0 // indirect
github.com/ysmood/gson v0.7.3 // indirect
github.com/ysmood/leakless v0.9.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/mock v0.4.0 // indirect
)
@@ -69,7 +62,6 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gaukas/godicttls v0.0.3 // indirect
github.com/go-basic/ipv4 v1.0.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/goccy/go-json v0.10.2 // indirect

View File

@@ -46,8 +46,6 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
@@ -80,8 +78,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
@@ -89,6 +85,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@@ -261,22 +259,6 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
github.com/ysmood/fetchup v0.3.0 h1:UhYz9xnLEVn2ukSuK3KCgcznWpHMdrmbsPpllcylyu8=
github.com/ysmood/fetchup v0.3.0/go.mod h1:hbysoq65PXL0NQeNzUczNYIKpwpkwFL4LXMDEvIQq9A=
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg=
github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk=
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY=
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=

View File

@@ -11,6 +11,7 @@ import (
"context"
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
@@ -19,9 +20,10 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -45,6 +47,26 @@ func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, cap
}
}
// RegisterRoutes 注册路由
func (h *ManagerHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/")
// 公开接口,不需要授权
group.POST("login", h.Login)
group.GET("logout", h.Logout)
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
}
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data struct {
@@ -59,19 +81,6 @@ func (h *ManagerHandler) Login(c *gin.Context) {
return
}
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {
@@ -135,16 +144,15 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
// Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) {
id := h.GetLoginUserId(c)
key := fmt.Sprintf("admin/%d", id)
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
resp.NotAuth(c)
id := h.GetAdminId(c)
if id == 0 {
resp.NotAuth(c, "当前用户已退出登录")
return
}
var manager model.AdminUser
res := h.DB.Where("id", id).First(&manager)
if res.Error != nil {
resp.NotAuth(c)
err := h.DB.Where("id", id).First(&manager).Error
if err != nil {
resp.NotAuth(c, "当前用户已退出登录")
return
}

View File

@@ -10,6 +10,7 @@ package admin
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
@@ -30,6 +31,20 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
}
// RegisterRoutes 注册路由
func (h *ApiKeyHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/apikey/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}
}
func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`

View File

@@ -10,6 +10,7 @@ package admin
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
@@ -30,14 +31,29 @@ func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatAppHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/role/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
}
}
// Save 创建或者更新某个角色
func (h *ChatAppHandler) Save(c *gin.Context) {
var data vo.ChatRole
var data vo.ChatApp
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var role model.ChatRole
var role model.ChatApp
err := utils.CopyObject(data, &role)
if err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -65,8 +81,8 @@ func (h *ChatAppHandler) Save(c *gin.Context) {
}
func (h *ChatAppHandler) List(c *gin.Context) {
var items []model.ChatRole
var roles = make([]vo.ChatRole, 0)
var items []model.ChatApp
var roles = make([]vo.ChatApp, 0)
res := h.DB.Order("sort_num ASC").Find(&items)
if res.Error != nil {
resp.ERROR(c, "No data found")
@@ -107,7 +123,7 @@ func (h *ChatAppHandler) List(c *gin.Context) {
}
for _, v := range items {
var role vo.ChatRole
var role vo.ChatApp
err := utils.CopyObject(v, &role)
if err == nil {
role.Id = v.Id
@@ -135,7 +151,7 @@ func (h *ChatAppHandler) Sort(c *gin.Context) {
}
for index, id := range data.Ids {
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -157,7 +173,7 @@ func (h *ChatAppHandler) Set(c *gin.Context) {
return
}
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
if err != nil {
resp.ERROR(c, err.Error())
return
@@ -172,9 +188,8 @@ func (h *ChatAppHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
res := h.DB.Where("id", id).Delete(&model.ChatApp{})
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "删除失败!")
return
}

View File

@@ -2,12 +2,14 @@ package admin
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -20,6 +22,21 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatAppTypeHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/app/type/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.GET("remove", h.Remove)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
}
}
// Save 创建或更新App类型
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
var data struct {

View File

@@ -9,6 +9,7 @@ package admin
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
@@ -28,16 +29,31 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/chat/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("list", h.List)
group.POST("message", h.Messages)
group.GET("history", h.History)
group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage)
}
}
type chatItemVo struct {
Username string `json:"username"`
UserId uint `json:"user_id"`
ChatId string `json:"chat_id"`
Title string `json:"title"`
Role vo.ChatRole `json:"role"`
Model string `json:"model"`
Token int `json:"token"`
CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
Username string `json:"username"`
UserId uint `json:"user_id"`
ChatId string `json:"chat_id"`
Title string `json:"title"`
Role vo.ChatApp `json:"role"`
Model string `json:"model"`
Token int `json:"token"`
CreatedAt int64 `json:"created_at"`
MsgNum int `json:"msg_num"` // 消息数量
}
func (h *ChatHandler) List(c *gin.Context) {
@@ -87,7 +103,7 @@ func (h *ChatHandler) List(c *gin.Context) {
}
var messages []model.ChatMessage
var users []model.User
var roles []model.ChatRole
var roles []model.ChatApp
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
h.DB.Where("id IN ?", userIds).Find(&users)
h.DB.Where("id IN ?", roleIds).Find(&roles)
@@ -95,7 +111,7 @@ func (h *ChatHandler) List(c *gin.Context) {
tokenMap := make(map[string]int)
userMap := make(map[uint]string)
msgMap := make(map[string]int)
roleMap := make(map[uint]vo.ChatRole)
roleMap := make(map[uint]vo.ChatApp)
for _, msg := range messages {
tokenMap[msg.ChatId] += msg.Tokens
msgMap[msg.ChatId] += 1
@@ -104,7 +120,7 @@ func (h *ChatHandler) List(c *gin.Context) {
userMap[user.Id] = user.Username
}
for _, r := range roles {
var roleVo vo.ChatRole
var roleVo vo.ChatApp
err := utils.CopyObject(r, &roleVo)
if err != nil {
continue

View File

@@ -8,7 +8,9 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
@@ -28,6 +30,22 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatModelHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/model/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("set", h.Set)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
group.POST("batch-remove", h.BatchRemove)
}
}
func (h *ChatModelHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
@@ -201,3 +219,33 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
}
resp.SUCCESS(c)
}
// BatchRemove 批量删除模型
func (h *ChatModelHandler) BatchRemove(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if len(data.Ids) == 0 {
resp.ERROR(c, "请选择要删除的模型")
return
}
// 执行批量删除
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.ChatModel{}).Error
if err != nil {
logger.Error("批量删除模型失败:", err)
resp.ERROR(c, "批量删除失败:"+err.Error())
return
}
resp.SUCCESS(c, gin.H{
"message": fmt.Sprintf("成功删除 %d 个模型", len(data.Ids)),
"deleted_count": len(data.Ids),
})
}

View File

@@ -9,106 +9,399 @@ package admin
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/store"
"geekai/service/oss"
"geekai/service/payment"
"geekai/service/sms"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm"
)
type ConfigHandler struct {
handler.BaseHandler
levelDB *store.LevelDB
licenseService *service.LicenseService
licenseService *service.LicenseService
sysConfig *types.SystemConfig
alipayService *payment.AlipayService
wxpayService *payment.WxPayService
epayService *payment.EPayService
smsManager *sms.SmsManager
uploaderManager *oss.UploaderManager
smtpService *service.SmtpService
captchaService *service.CaptchaService
wxLoginService *service.WxLoginService
}
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
func NewConfigHandler(
app *core.AppServer,
db *gorm.DB,
licenseService *service.LicenseService,
sysConfig *types.SystemConfig,
alipayService *payment.AlipayService,
wxpayService *payment.WxPayService,
epayService *payment.EPayService,
smsManager *sms.SmsManager,
uploaderManager *oss.UploaderManager,
smtpService *service.SmtpService,
captchaService *service.CaptchaService,
wxLoginService *service.WxLoginService,
) *ConfigHandler {
return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB,
licenseService: licenseService,
BaseHandler: handler.BaseHandler{App: app, DB: db},
licenseService: licenseService,
sysConfig: sysConfig,
alipayService: alipayService,
wxpayService: wxpayService,
epayService: epayService,
smsManager: smsManager,
uploaderManager: uploaderManager,
smtpService: smtpService,
captchaService: captchaService,
wxLoginService: wxLoginService,
}
}
func (h *ConfigHandler) Update(c *gin.Context) {
var data struct {
Key string `json:"key"`
Config struct {
types.SystemConfig
Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"`
} `json:"config"`
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
// RegisterRoutes 注册路由
func (h *ConfigHandler) RegisterRoutes() {
rg := h.App.Engine.Group("/api/admin/config")
// 需要管理员登录的接口
rg.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
rg.POST("update/base", h.UpdateBase)
rg.POST("update/power", h.UpdatePower)
rg.POST("update/notice", h.UpdateNotice)
rg.POST("update/agreement", h.UpdateAgreement)
rg.POST("update/privacy", h.UpdatePrivacy)
rg.POST("update/mark_map", h.UpdateMarkMap)
rg.POST("update/captcha", h.UpdateCaptcha)
rg.POST("update/wx_login", h.UpdateWxLogin)
rg.POST("update/payment", h.UpdatePayment)
rg.POST("update/sms", h.UpdateSms)
rg.POST("update/oss", h.UpdateOss)
rg.POST("update/smtp", h.UpdateStmp)
rg.GET("get", h.Get)
rg.POST("license/active", h.Active)
rg.GET("license/get", h.GetLicense)
}
}
// UpdateBase 更新基础配置
func (h *ConfigHandler) UpdateBase(c *gin.Context) {
var data types.BaseConfig
if err := c.ShouldBindJSON(&data); err != nil {
logger.Errorf("Update config failed: %v", err)
resp.ERROR(c, types.InvalidArgs)
return
}
// ONLY authorized user can change the copyright
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
// 未授权的话不允许修改版权
license := h.licenseService.GetLicense()
if !license.IsActive && data.Copyright != h.sysConfig.Base.Copyright {
resp.ERROR(c, "未授权系统不允许修改版权信息")
return
}
// 如果要启用图形验证码功能,则检查是否配置了 API 服务
if data.Config.EnabledVerify && h.App.Config.ApiConfig.AppId == "" {
resp.ERROR(c, "启用验证码服务需要先配置 GeekAI 官方 API 服务 AppId 和 Token")
// 未授权的话不允许修改 Logo
if !license.IsActive && data.Logo != h.sysConfig.Base.Logo {
resp.ERROR(c, "未授权系统不允许修改 Logo")
return
}
value := utils.JsonEncode(&data.Config)
config := model.Config{Name: data.Key, Value: value}
res := h.DB.FirstOrCreate(&config, model.Config{Name: data.Key})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
err := h.Update(types.ConfigKeySystem, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if config.Id > 0 {
config.Value = value
res := h.DB.Updates(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
h.sysConfig.Base = data
// update config cache for AppServer
var cfg model.Config
h.DB.Where("name", data.Key).First(&cfg)
var err error
if data.Key == "system" {
err = utils.JsonDecode(cfg.Value, &h.App.SysConfig)
}
if err != nil {
resp.ERROR(c, "Failed to update config cache: "+err.Error())
return
}
logger.Infof("Update AppServer's config successfully: %v", config.Value)
}
resp.SUCCESS(c, config)
resp.SUCCESS(c, data)
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")
// UpdatePower 更新系统配置
func (h *ConfigHandler) UpdatePower(c *gin.Context) {
var data struct {
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
h.sysConfig.Base.InitPower = data.InitPower
h.sysConfig.Base.DailyPower = data.DailyPower
h.sysConfig.Base.InvitePower = data.InvitePower
h.sysConfig.Base.MjPower = data.MjPower
h.sysConfig.Base.MjActionPower = data.MjActionPower
h.sysConfig.Base.SdPower = data.SdPower
h.sysConfig.Base.SunoPower = data.SunoPower
h.sysConfig.Base.LumaPower = data.LumaPower
h.sysConfig.Base.KeLingPowers = data.KeLingPowers
err := h.Update(types.ConfigKeySystem, h.sysConfig.Base)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, h.sysConfig.Base)
}
// UpdateNotice 更新公告配置
func (h *ConfigHandler) UpdateNotice(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyNotice, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdateAgreement 更新用户协议配置
func (h *ConfigHandler) UpdateAgreement(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyAgreement, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdatePrivacy 更新隐私政策配置
func (h *ConfigHandler) UpdatePrivacy(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyPrivacy, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdateMarkMap 更新思维导图配置
func (h *ConfigHandler) UpdateMarkMap(c *gin.Context) {
var data struct {
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyMarkMap, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, data)
}
// UpdateCaptcha 更新行为验证码配置
func (h *ConfigHandler) UpdateCaptcha(c *gin.Context) {
var data types.CaptchaConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyCaptcha, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
h.captchaService.UpdateConfig(data)
resp.SUCCESS(c, data)
}
// UpdatePayment 更新支付配置
func (h *ConfigHandler) UpdatePayment(c *gin.Context) {
var data types.PaymentConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyPayment, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 如果启用状态发生改变,则需要更新支付服务配置
if data.WxPay.Enabled {
err = h.wxpayService.UpdateConfig(&data.WxPay)
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
if data.Epay.Enabled {
h.epayService.UpdateConfig(&data.Epay)
}
if data.Alipay.Enabled {
err = h.alipayService.UpdateConfig(&data.Alipay)
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
h.sysConfig.Payment = data
resp.SUCCESS(c, data)
}
// UpdateSms 更新短信配置
func (h *ConfigHandler) UpdateSms(c *gin.Context) {
var data types.SMSConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeySms, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 更新服务配置
h.smsManager.UpdateConfig(data)
resp.SUCCESS(c, data)
}
// UpdateOss 更新 Oss 配置
func (h *ConfigHandler) UpdateOss(c *gin.Context) {
var data types.OSSConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyOss, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 更新服务配置
h.uploaderManager.UpdateConfig(data)
h.sysConfig.OSS = data
resp.SUCCESS(c, data)
}
// UpdateStmp 更新 Stmp 配置
func (h *ConfigHandler) UpdateStmp(c *gin.Context) {
var data types.SmtpConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeySmtp, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 更新服务配置
h.smtpService.UpdateConfig(&data)
h.sysConfig.SMTP = data
resp.SUCCESS(c, data)
}
// UpdateWxLogin 更新微信登录配置
func (h *ConfigHandler) UpdateWxLogin(c *gin.Context) {
var data types.WxLoginConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.Update(types.ConfigKeyWxLogin, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if data.Enabled {
h.wxLoginService.UpdateConfig(data)
}
h.sysConfig.WxLogin = data
resp.SUCCESS(c, data)
}
// Update 更新系统配置
func (h *ConfigHandler) Update(name string, value any) error {
var config model.Config
res := h.DB.Where("name", key).First(&config)
err := h.DB.Where("name", name).First(&config).Error
if err != nil { // 不存在则创建
config.Name = name
config.Value = utils.JsonEncode(value)
return h.DB.Create(&config).Error
} else { // 存在则更新
config.Value = utils.JsonEncode(value)
return h.DB.Updates(&config).Error
}
}
// Get 获取指定名称的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
name := c.Query("key")
var config model.Config
res := h.DB.Where("name", name).First(&config)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
var value map[string]interface{}
var value map[string]any
err := utils.JsonDecode(config.Value, &value)
if err != nil {
resp.ERROR(c, err.Error())
@@ -127,19 +420,21 @@ func (h *ConfigHandler) Active(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
info, err := host.Info()
err := h.licenseService.ActiveLicense(data.License)
license := h.licenseService.GetLicense()
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.licenseService.ActiveLicense(data.License, info.HostID)
if err != nil {
if err := h.Update(types.ConfigKeyLicense, license); err != nil {
resp.ERROR(c, err.Error())
return
}
// 更新系统配置
h.sysConfig.License = *license
resp.SUCCESS(c)
resp.SUCCESS(c, license.MachineId)
}
@@ -148,69 +443,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense()
resp.SUCCESS(c, license)
}
// FixData 修复数据
func (h *ConfigHandler) FixData(c *gin.Context) {
resp.ERROR(c, "当前升级版本没有数据需要修正!")
//var fixed bool
//version := "data_fix_4.1.4"
//err := h.levelDB.Get(version, &fixed)
//if err == nil || fixed {
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
// return
//}
//tx := h.DB.Begin()
//var users []model.User
//err = tx.Find(&users).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, user := range users {
// if user.Email != "" || user.Mobile != "" {
// continue
// }
// if utils.IsValidEmail(user.Username) {
// user.Email = user.Username
// } else if utils.IsValidMobile(user.Username) {
// user.Mobile = user.Username
// }
// err = tx.Save(&user).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//
//var orders []model.Order
//err = h.DB.Find(&orders).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, order := range orders {
// if order.PayWay == "支付宝" {
// order.PayWay = "alipay"
// order.PayType = "alipay"
// } else if order.PayWay == "微信支付" {
// order.PayWay = "wechat"
// order.PayType = "wxpay"
// } else if order.PayWay == "hupi" {
// order.PayType = "wxpay"
// }
// err = tx.Save(&order).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//tx.Commit()
//err = h.levelDB.Put(version, true)
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//resp.SUCCESS(c)
}

View File

@@ -13,10 +13,11 @@ import (
"geekai/handler"
"geekai/store/model"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"time"
)
type DashboardHandler struct {
@@ -27,46 +28,161 @@ func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *DashboardHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/dashboard/")
group.GET("stats", h.Stats)
}
// statsVo 增加 recentOrders、recentUsers 字段
// 最近订单
type OrderBrief struct {
OrderNo string `json:"order_no"`
Amount float64 `json:"amount"`
CreatedAt time.Time `json:"created_at"`
}
// 最近用户
type UserBrief struct {
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
LastActive time.Time `json:"last_active"`
}
type statsVo struct {
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Chart map[string]map[string]float64 `json:"chart"`
Users int64 `json:"users"`
Chats int64 `json:"chats"`
Tokens int `json:"tokens"`
Income float64 `json:"income"`
Chart map[string]map[string]float64 `json:"chart"`
TodayUsers int64 `json:"todayUsers"`
TodayChats int64 `json:"todayChats"`
TodayTokens int `json:"todayTokens"`
TodayIncome float64 `json:"todayIncome"`
TodayOrders int64 `json:"todayOrders"`
TodayImageJobs int64 `json:"todayImageJobs"`
TodayVideoJobs int64 `json:"todayVideoJobs"`
TodayMusicJobs int64 `json:"todayMusicJobs"`
Orders int64 `json:"orders"`
ImageJobs int64 `json:"imageJobs"`
VideoJobs int64 `json:"videoJobs"`
MusicJobs int64 `json:"musicJobs"`
RecentOrders []OrderBrief `json:"recentOrders"`
RecentUsers []UserBrief `json:"recentUsers"`
}
func (h *DashboardHandler) Stats(c *gin.Context) {
stats := statsVo{}
// new users statistic
var userCount int64
now := time.Now()
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
if res.Error == nil {
stats.Users = userCount
// 总用户数
h.DB.Model(&model.User{}).Count(&stats.Users)
// 今日新增用户
h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&stats.TodayUsers)
// 总对话数
h.DB.Model(&model.ChatItem{}).Count(&stats.Chats)
// 今日新增对话
h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&stats.TodayChats)
// 总算力消耗
var powerLogs []model.PowerLog
h.DB.Where("mark = ?", types.PowerSub).Find(&powerLogs)
for _, item := range powerLogs {
stats.Tokens += item.Amount
}
// new chats statistic
var chatCount int64
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
if res.Error == nil {
stats.Chats = chatCount
// 今日算力消耗
var todayPowerLogs []model.PowerLog
h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", zeroTime).Find(&todayPowerLogs)
for _, item := range todayPowerLogs {
stats.TodayTokens += item.Amount
}
// tokens took stats
var historyMessages []model.ChatMessage
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
for _, item := range historyMessages {
stats.Tokens += item.Tokens
}
// 订单收入
var orders []model.Order
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
for _, item := range orders {
// 总收入
var allOrders []model.Order
h.DB.Where("status = ?", types.OrderPaidSuccess).Find(&allOrders)
for _, item := range allOrders {
stats.Income += item.Amount
}
// 今日收入
var todayOrders []model.Order
h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&todayOrders)
for _, item := range todayOrders {
stats.TodayIncome += item.Amount
}
// 订单总数
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Count(&stats.Orders)
// 今日订单数
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Count(&stats.TodayOrders)
// 图片生成任务统计
var mjJobs, sdJobs, dallJobs, jimengImageJobs int64
h.DB.Model(&model.MidJourneyJob{}).Count(&mjJobs)
h.DB.Model(&model.SdJob{}).Count(&sdJobs)
h.DB.Model(&model.DallJob{}).Count(&dallJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Count(&jimengImageJobs)
stats.ImageJobs = mjJobs + sdJobs + dallJobs + jimengImageJobs
logger.Info("stats.ImageJobs", stats.ImageJobs)
// 今日图片生成任务统计
var todayMjJobs, todaySdJobs, todayDallJobs, todayJimengImageJobs int64
h.DB.Model(&model.MidJourneyJob{}).Where("created_at > ?", zeroTime).Count(&todayMjJobs)
h.DB.Model(&model.SdJob{}).Where("created_at > ?", zeroTime).Count(&todaySdJobs)
h.DB.Model(&model.DallJob{}).Where("created_at > ?", zeroTime).Count(&todayDallJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Where("created_at > ?", zeroTime).Count(&todayJimengImageJobs)
stats.TodayImageJobs = todayMjJobs + todaySdJobs + todayDallJobs + todayJimengImageJobs
// 视频生成任务统计
var videoJobs, jimengVideoJobs int64
h.DB.Model(&model.VideoJob{}).Count(&videoJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Count(&jimengVideoJobs)
stats.VideoJobs = videoJobs + jimengVideoJobs
// 今日视频生成任务统计
var todayVideoJobs, todayJimengVideoJobs int64
h.DB.Model(&model.VideoJob{}).Where("created_at > ?", zeroTime).Count(&todayVideoJobs)
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Where("created_at > ?", zeroTime).Count(&todayJimengVideoJobs)
stats.TodayVideoJobs = todayVideoJobs + todayJimengVideoJobs
// 音乐生成任务统计
h.DB.Model(&model.SunoJob{}).Count(&stats.MusicJobs)
// 今日音乐生成任务统计
h.DB.Model(&model.SunoJob{}).Where("created_at > ?", zeroTime).Count(&stats.TodayMusicJobs)
// recentOrders: 最近10条已支付订单
var orderList []model.Order
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Order("created_at desc").Limit(10).Find(&orderList)
for _, o := range orderList {
stats.RecentOrders = append(stats.RecentOrders, OrderBrief{
OrderNo: o.OrderNo,
Amount: o.Amount,
CreatedAt: o.CreatedAt,
})
}
// recentUsers: 最近10个注册用户
var userList []model.User
h.DB.Model(&model.User{}).Order("created_at desc").Limit(10).Find(&userList)
for _, u := range userList {
lastActive := u.UpdatedAt
if lastActive.IsZero() {
lastActive = u.CreatedAt
}
stats.RecentUsers = append(stats.RecentUsers, UserBrief{
Nickname: u.Nickname,
Avatar: u.Avatar,
LastActive: lastActive,
})
}
// 统计7天的订单的图表
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
var statsChart = make(map[string]map[string]float64)
@@ -81,23 +197,29 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
// 统计用户7天增加的曲线
var users []model.User
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
if res.Error == nil {
err := h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users).Error
if err == nil {
for _, item := range users {
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
}
}
// 统计7天Token 消耗
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
for _, item := range historyMessages {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
// 统计7天算力消耗
var chartPowerLogs []model.PowerLog
err = h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", startDate).Find(&chartPowerLogs).Error
if err == nil {
for _, item := range chartPowerLogs {
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Amount)
}
}
// 统计最近7天的订单
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
var orders []model.Order
err = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders).Error
if err == nil {
for _, item := range orders {
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
}
}
statsChart["users"] = userStatistic

View File

@@ -9,6 +9,7 @@ package admin
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/store/model"
@@ -30,6 +31,21 @@ func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *FunctionHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/function/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.GET("token", h.GenToken)
}
}
func (h *FunctionHandler) Save(c *gin.Context) {
var data vo.Function
if err := c.ShouldBindJSON(&data); err != nil {
@@ -119,7 +135,6 @@ func (h *FunctionHandler) GenToken(c *gin.Context) {
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
logger.Error("error with generate token", err)
resp.ERROR(c)
return
}

View File

@@ -10,6 +10,7 @@ package admin
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service"
@@ -33,6 +34,20 @@ func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.User
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
// RegisterRoutes 注册路由
func (h *ImageHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/image/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("list/mj", h.MjList)
group.POST("list/sd", h.SdList)
group.POST("list/dall", h.DallList)
group.GET("remove", h.Remove)
}
}
type imageQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`

View File

@@ -21,18 +21,18 @@ import (
// AdminJimengHandler 管理后台即梦AI处理器
type AdminJimengHandler struct {
handler.BaseHandler
jimengService *jimeng.Service
userService *service.UserService
uploader *oss.UploaderManager
jimengClient *jimeng.Client
userService *service.UserService
uploader *oss.UploaderManager
}
// NewAdminJimengHandler 创建管理后台即梦AI处理器
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengClient *jimeng.Client, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
return &AdminJimengHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
jimengService: jimengService,
userService: userService,
uploader: uploader,
BaseHandler: handler.BaseHandler{App: app, DB: db},
jimengClient: jimengClient,
userService: userService,
uploader: uploader,
}
}
@@ -43,7 +43,6 @@ func (h *AdminJimengHandler) RegisterRoutes() {
rg.GET("/jobs/:id", h.JobDetail)
rg.POST("/jobs/remove", h.BatchRemove)
rg.GET("/stats", h.Stats)
rg.GET("/config", h.GetConfig)
rg.POST("/config/update", h.UpdateConfig)
}
@@ -213,12 +212,6 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
resp.SUCCESS(c, result)
}
// GetConfig 获取即梦AI配置
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
jimengConfig := h.jimengService.GetConfig()
resp.SUCCESS(c, jimengConfig)
}
// UpdateConfig 更新即梦AI配置
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
var req types.JimengConfig
@@ -266,31 +259,35 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
// 保存配置
tx := h.DB.Begin()
value := utils.JsonEncode(&req)
config := model.Config{Name: "jimeng", Value: value}
var exist model.Config
tx.Where("name", types.ConfigKeyJimeng).First(&exist)
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
if err != nil {
resp.ERROR(c, "保存配置失败: "+err.Error())
return
}
if config.Id > 0 {
config.Value = value
err = tx.Updates(&config).Error
if exist.Id > 0 {
exist.Value = value
err := tx.Updates(&exist).Error
if err != nil {
resp.ERROR(c, "更新配置失败: "+err.Error())
return
}
} else {
exist.Name = types.ConfigKeyJimeng
exist.Value = value
err := tx.Create(&exist).Error
if err != nil {
resp.ERROR(c, "创建配置失败: "+err.Error())
return
}
}
// 更新服务中的客户端配置
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey)
if updateErr != nil {
resp.ERROR(c, updateErr.Error())
err := h.jimengClient.UpdateConfig(req)
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return
}
tx.Commit()
h.App.SysConfig.Jimeng = req
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
}

View File

@@ -10,6 +10,7 @@ package admin
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service"
@@ -33,6 +34,19 @@ func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.User
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
// RegisterRoutes 注册路由
func (h *MediaHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/media/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("suno", h.SunoList)
group.POST("videos", h.Videos)
group.GET("remove", h.Remove)
}
}
type mediaQuery struct {
Type string `json:"type"` // 任务类型 luma, keling
Prompt string `json:"prompt"`

View File

@@ -27,6 +27,16 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *MenuHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/menu/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}
func (h *MenuHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`

View File

@@ -0,0 +1,333 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service/moderation"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ModerationHandler struct {
handler.BaseHandler
sysConfig *types.SystemConfig
moderationManager *moderation.ServiceManager
}
func NewModerationHandler(app *core.AppServer, db *gorm.DB, sysConfig *types.SystemConfig, moderationManager *moderation.ServiceManager) *ModerationHandler {
return &ModerationHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, sysConfig: sysConfig, moderationManager: moderationManager}
}
// RegisterRoutes 注册路由
func (h *ModerationHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/moderation/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("list", h.List)
group.GET("remove", h.Remove)
group.POST("batch-remove", h.BatchRemove)
group.GET("source-list", h.GetSourceList)
group.POST("config", h.UpdateModeration)
group.POST("test", h.TestModeration)
}
}
// List 获取文本审核记录列表
func (h *ModerationHandler) List(c *gin.Context) {
var data struct {
Username string `json:"username"`
Source string `json:"source"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
// 构建查询条件
if data.Username != "" {
// 通过用户名查找用户ID
var user model.User
if err := h.DB.Where("username LIKE ?", "%"+data.Username+"%").First(&user).Error; err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Source != "" {
session = session.Where("source", data.Source)
}
if data.StartDate != "" && data.EndDate != "" {
startTime := data.StartDate + " 00:00:00"
endTime := data.EndDate + " 23:59:59"
session = session.Where("created_at >= ? AND created_at <= ?", startTime, endTime)
}
// 统计总数
var total int64
session.Model(&model.Moderation{}).Count(&total)
// 分页
page := data.Page
pageSize := data.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
// 查询数据
var items []model.Moderation
err := session.Order("id DESC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 获取用户信息
userIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
}
var users []model.User
if len(userIds) > 0 {
h.DB.Where("id IN ?", userIds).Find(&users)
}
userMap := make(map[uint]string)
for _, user := range users {
userMap[user.Id] = user.Username
}
// 转换为响应数据
list := make([]map[string]any, 0)
for _, item := range items {
var moderation types.ModerationResult
err := utils.JsonDecode(item.Result, &moderation)
if err != nil {
continue
}
var result []string
for value, label := range types.ModerationCategories {
if moderation.Categories[value] {
result = append(result, label)
}
}
list = append(list, map[string]any{
"id": item.Id,
"user_id": item.UserId,
"username": userMap[item.UserId],
"source": item.Source,
"input": item.Input,
"output": item.Output,
"result": result,
"created_at": item.CreatedAt.Unix(),
})
}
resp.SUCCESS(c, map[string]any{
"items": list,
"total": total,
"page": page,
"page_size": pageSize,
})
}
func (h *ModerationHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
err := h.DB.Where("id", id).Delete(&model.Moderation{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// BatchRemove 批量删除文本审核记录
func (h *ModerationHandler) BatchRemove(c *gin.Context) {
var data struct {
Ids []uint `json:"ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if len(data.Ids) == 0 {
resp.ERROR(c, "请选择要删除的记录")
return
}
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.Moderation{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// 获取 source 列表
func (h *ModerationHandler) GetSourceList(c *gin.Context) {
sources := []gin.H{
{
"id": types.ModerationSourceChat,
"name": "AI对话",
},
{
"id": types.ModerationSourceMJ,
"name": "Midjourney 绘图",
},
{
"id": types.ModerationSourceDalle,
"name": "Dalle 绘图",
},
{
"id": types.ModerationSourceSD,
"name": "StableDiffusion 绘图",
},
{
"id": types.ModerationSourceSuno,
"name": "Suno 音乐",
},
{
"id": types.ModerationSourceVideo,
"name": "视频生成",
},
{
"id": types.ModerationSourceJiMeng,
"name": "即梦AI",
},
}
resp.SUCCESS(c, sources)
}
// UpdateModeration 更新文本审查配置
func (h *ModerationHandler) UpdateModeration(c *gin.Context) {
var data types.ModerationConfig
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
var config model.Config
err := h.DB.Where("name", types.ConfigKeyModeration).First(&config).Error
if err != nil {
config.Name = types.ConfigKeyModeration
config.Value = utils.JsonEncode(data)
err = h.DB.Create(&config).Error
} else {
config.Value = utils.JsonEncode(data)
err = h.DB.Updates(&config).Error
}
if err != nil {
resp.ERROR(c, err.Error())
return
}
h.moderationManager.UpdateConfig(data)
h.sysConfig.Moderation = data
resp.SUCCESS(c, data)
}
// 测试结果类型,用于前端显示
type ModerationTestResult struct {
IsAbnormal bool `json:"isAbnormal"`
Details []ModerationTestDetail `json:"details"`
}
type ModerationTestDetail struct {
Category string `json:"category"`
Description string `json:"description"`
Confidence string `json:"confidence"`
IsCategory bool `json:"isCategory"`
}
// TestModeration 测试文本审查服务
func (h *ModerationHandler) TestModeration(c *gin.Context) {
var data struct {
Text string `json:"text"`
Service string `json:"service"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Text == "" {
resp.ERROR(c, "测试文本不能为空")
return
}
// 检查是否启用了文本审查
if !h.sysConfig.Moderation.Enable {
resp.ERROR(c, "文本审查服务未启用")
return
}
// 获取当前激活的审核服务
service := h.moderationManager.GetService()
// 执行文本审核
result, err := service.Moderate(data.Text)
if err != nil {
resp.ERROR(c, "审核服务调用失败: "+err.Error())
return
}
// 转换为前端需要的格式
testResult := ModerationTestResult{
IsAbnormal: result.Flagged,
Details: make([]ModerationTestDetail, 0),
}
// 构建详细信息
for category, description := range types.ModerationCategories {
score := result.CategoryScores[category]
isCategory := result.Categories[category]
testResult.Details = append(testResult.Details, ModerationTestDetail{
Category: category,
Description: description,
Confidence: fmt.Sprintf("%.2f", score),
IsCategory: isCategory,
})
}
resp.SUCCESS(c, testResult)
}

View File

@@ -29,6 +29,14 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *OrderHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/order/")
group.POST("list", h.List)
group.GET("remove", h.Remove)
group.GET("clear", h.Clear)
}
func (h *OrderHandler) List(c *gin.Context) {
var data struct {
OrderNo string `json:"order_no"`
@@ -68,16 +76,16 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
payChannel, ok := types.PayChannel[item.Channel]
if !ok {
payMethod = item.PayWay
payChannel = item.Channel
}
payName, ok := types.PayNames[item.PayType]
payWays, ok := types.PayWays[item.PayWay]
if !ok {
payName = item.PayWay
payWays = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
order.ChannelName = payChannel
order.PayName = payWays
list = append(list, order)
} else {
logger.Error(err)
@@ -121,8 +129,8 @@ func (h *OrderHandler) Clear(c *gin.Context) {
}
deleteIds := make([]uint, 0)
for _, order := range orders {
// 只删除 15 分钟内的未支付订单
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
// 只删除超时的未支付订单
if time.Now().After(order.CreatedAt.Add(time.Minute * time.Duration(h.App.SysConfig.Base.OrderPayTimeout))) {
deleteIds = append(deleteIds, order.Id)
}
}

View File

@@ -28,6 +28,12 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *PowerLogHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Username string `json:"username"`

View File

@@ -15,9 +15,10 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"time"
)
type ProductHandler struct {
@@ -28,14 +29,22 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ProductHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/product/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
}
func (h *ProductHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Name string `json:"name"`
Price float64 `json:"price"`
Discount float64 `json:"discount"`
Enabled bool `json:"enabled"`
Days int `json:"days"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"`
}
@@ -45,12 +54,10 @@ func (h *ProductHandler) Save(c *gin.Context) {
}
item := model.Product{
Name: data.Name,
Price: data.Price,
Discount: data.Discount,
Days: data.Days,
Power: data.Power,
Enabled: data.Enabled}
Name: data.Name,
Price: data.Price,
Power: data.Power,
Enabled: data.Enabled}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)

View File

@@ -29,6 +29,16 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *RedeemHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/redeem/")
group.GET("list", h.List)
group.POST("create", h.Create)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("export", h.Export)
}
func (h *RedeemHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)

View File

@@ -9,6 +9,7 @@ package admin
import (
"geekai/core"
"geekai/core/middleware"
"geekai/handler"
"geekai/service/oss"
"geekai/store/model"
@@ -28,6 +29,17 @@ func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderMan
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
}
// RegisterRoutes 注册路由
func (h *UploadHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/upload")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.POST("", h.Upload)
}
}
func (h *UploadHandler) Upload(c *gin.Context) {
// 判断文件大小
f, err := c.FormFile("file")
@@ -36,7 +48,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
return
}
if h.App.SysConfig.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.MaxFileSize)*1024*1024 {
if h.App.SysConfig.Base.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.Base.MaxFileSize)*1024*1024 {
resp.ERROR(c, "文件大小超过限制")
return
}

View File

@@ -10,6 +10,7 @@ package admin
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/handler"
"geekai/service"
@@ -19,10 +20,9 @@ import (
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -36,6 +36,22 @@ func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.Li
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
}
// RegisterRoutes 注册路由
func (h *UserHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/admin/user/")
// 需要管理员授权的接口
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.POST("save", h.Save)
group.GET("remove", h.Remove)
group.GET("loginLog", h.LoginLog)
group.GET("genLoginLink", h.GenLoginLink)
group.POST("resetPass", h.ResetPass)
}
}
// List 用户列表
func (h *UserHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)

View File

@@ -15,9 +15,10 @@ import (
logger2 "geekai/logger"
"geekai/store/model"
"geekai/utils"
"gorm.io/gorm"
"strings"
"gorm.io/gorm"
"github.com/gin-gonic/gin"
)
@@ -69,6 +70,14 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
}
func (h *BaseHandler) GetAdminId(c *gin.Context) uint {
userId, ok := c.Get(types.AdminUserID)
if !ok {
return 0
}
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
}
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
return h.GetLoginUserId(c) > 0
}

View File

@@ -8,23 +8,45 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
)
// 今日头条函数实现
type CaptchaHandler struct {
App *core.AppServer
service *service.CaptchaService
}
func NewCaptchaHandler(s *service.CaptchaService) *CaptchaHandler {
return &CaptchaHandler{service: s}
func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, sysConfig *types.SystemConfig) *CaptchaHandler {
return &CaptchaHandler{App: app, service: s}
}
// RegisterRoutes 注册路由
func (h *CaptchaHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/captcha/")
// 无需授权的接口
group.GET("get", h.Get)
group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
group.GET("config", h.GetConfig)
}
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
resp.SUCCESS(c, gin.H{"enabled": h.service.GetConfig().Enabled, "type": h.service.GetConfig().Type})
}
func (h *CaptchaHandler) Get(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
data, err := h.service.Get()
if err != nil {
resp.ERROR(c, err.Error())
@@ -36,6 +58,11 @@ func (h *CaptchaHandler) Get(c *gin.Context) {
// Check verify the captcha data
func (h *CaptchaHandler) Check(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
var data struct {
Key string `json:"key"`
Dots string `json:"dots"`
@@ -55,6 +82,11 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
// SlideGet 获取滑动验证图片
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
data, err := h.service.SlideGet()
if err != nil {
resp.ERROR(c, err.Error())
@@ -66,6 +98,11 @@ func (h *CaptchaHandler) SlideGet(c *gin.Context) {
// SlideCheck 滑动验证结果校验
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
if !h.service.GetConfig().Enabled {
resp.ERROR(c, "验证码服务未启用")
return
}
var data struct {
Key string `json:"key"`
X int `json:"x"`

View File

@@ -9,6 +9,7 @@ package handler
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
@@ -19,18 +20,31 @@ import (
"gorm.io/gorm"
)
type ChatRoleHandler struct {
type ChatAppHandler struct {
BaseHandler
}
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
return &ChatAppHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatAppHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/app/")
group.GET("list", h.List)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("list/user", h.ListByUser)
group.POST("update", h.UpdateApp)
}
}
// List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) {
func (h *ChatAppHandler) List(c *gin.Context) {
tid := h.GetInt(c, "tid", 0)
var roles []model.ChatRole
var roles []model.ChatApp
session := h.DB.Where("enable", true)
if tid > 0 {
session = session.Where("tid", tid)
@@ -41,9 +55,9 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
return
}
var roleVos = make([]vo.ChatRole, 0)
var roleVos = make([]vo.ChatApp, 0)
for _, r := range roles {
var v vo.ChatRole
var v vo.ChatApp
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
@@ -54,10 +68,10 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
}
// ListByUser 获取用户添加的角色列表
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
func (h *ChatAppHandler) ListByUser(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
var roles []model.ChatApp
session := h.DB.Where("enable", true)
// 如果用户没登录,则获取所有角色
if userId > 0 {
@@ -86,9 +100,9 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
return
}
var roleVos = make([]vo.ChatRole, 0)
var roleVos = make([]vo.ChatApp, 0)
for _, r := range roles {
var v vo.ChatRole
var v vo.ChatApp
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
@@ -98,8 +112,8 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
resp.SUCCESS(c, roleVos)
}
// UpdateRole 更新用户聊天角色
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
// UpdateApp 更新用户聊天应用
func (h *ChatAppHandler) UpdateApp(c *gin.Context) {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)

View File

@@ -19,6 +19,12 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatAppTypeHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/app/type/")
group.GET("list", h.List)
}
// List 获取App类型列表
func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType

View File

@@ -14,8 +14,10 @@ import (
"errors"
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
@@ -39,6 +41,7 @@ import (
const (
ChatEventStart = "start"
ChatEventEnd = "end"
ChatEventComplete = "complete"
ChatEventError = "error"
ChatEventMessageDelta = "message_delta"
ChatEventTitle = "title"
@@ -54,44 +57,69 @@ type ChatInput struct {
Stream bool `json:"stream"`
Files []vo.File `json:"files"`
ChatModel model.ChatModel `json:"chat_model,omitempty"`
ChatRole model.ChatRole `json:"chat_role,omitempty"`
ChatRole model.ChatApp `json:"chat_role,omitempty"`
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID用于重新生成答案的时候过滤上下文
}
type ChatHandler struct {
BaseHandler
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
userService *service.UserService
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
userService *service.UserService
moderationManager *moderation.ServiceManager
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService, moderationManager *moderation.ServiceManager) *ChatHandler {
return &ChatHandler{
BaseHandler: BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
userService: userService,
BaseHandler: BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
userService: userService,
moderationManager: moderationManager,
}
}
// RegisterRoutes 注册路由
func (h *ChatHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/chat/")
// 聊天接口不需要授权已在authConfig中配置
group.Any("message", h.Chat)
// 其他接口需要用户授权
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.GET("detail", h.Detail)
group.POST("update", h.Update)
group.GET("remove", h.Remove)
group.GET("history", h.History)
group.GET("clear", h.Clear)
group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate)
group.POST("tts", h.TextToSpeech)
}
}
// Chat 处理聊天请求
func (h *ChatHandler) Chat(c *gin.Context) {
var input ChatInput
if err := c.ShouldBindJSON(&input); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 设置SSE响应头
c.Header("Prompt-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
var input ChatInput
if err := c.ShouldBindJSON(&input); err != nil {
pushMessage(c, ChatEventError, types.InvalidArgs)
c.Abort()
return
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
@@ -113,7 +141,7 @@ func (h *ChatHandler) Chat(c *gin.Context) {
}
// 验证聊天角色
var chatRole model.ChatRole
var chatRole model.ChatApp
err := h.DB.First(&chatRole, input.RoleId).Error
if err != nil || !chatRole.Enable {
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
@@ -166,7 +194,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
}
if userVo.Power < input.ChatModel.Power {
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d[立即购买](/member)。", userVo.Power, input.ChatModel.Power)
return fmt.Errorf("您的算力不足,请购买算力。")
}
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
@@ -229,17 +257,24 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
// 加载聊天上下文
chatCtx := make([]any, 0)
messages := make([]any, 0)
if h.App.SysConfig.EnableContext {
if h.App.SysConfig.Base.EnableContext {
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
if h.App.SysConfig.Base.ContextDeep > 0 {
var historyMessages []model.ChatMessage
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
if input.LastMsgId > 0 { // 重新生成逻辑
var lastMessage model.ChatMessage
err = dbSession.Where("id <= ?", input.LastMsgId).Where("type", types.PromptMsg).First(&lastMessage).Error
if err != nil {
input.LastMsgId = 0
} else {
input.LastMsgId = lastMessage.Id
}
dbSession = dbSession.Where("id < ?", input.LastMsgId)
// 删除对应的聊天记录
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
}
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error
err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error
if err == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
@@ -267,7 +302,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
if len(chatCtx) >= h.App.SysConfig.Base.ContextDeep {
break
}
@@ -277,6 +312,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
}
reqMgs := make([]any, 0)
// 添加引导提示词,防止模型生成违规内容
if h.App.SysConfig.Moderation.EnableGuide {
reqMgs = append(reqMgs, map[string]any{
"role": "system",
"content": h.App.SysConfig.Moderation.GuidePrompt,
})
}
for i := len(chatCtx) - 1; i >= 0; i-- {
reqMgs = append(reqMgs, chatCtx[i])
}
@@ -295,16 +338,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
},
})
} else {
// 如果不是逆向模型,则提取文件内容
modelValue := input.ChatModel.Value
if !(strings.Contains(modelValue, "-all") || strings.HasPrefix(modelValue, "gpt-4-gizmo") || strings.HasPrefix(modelValue, "claude")) {
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
if err != nil {
logger.Error("error with read file: ", err)
continue
} else {
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
}
// 处理文件,提取文件内容
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
if err != nil {
logger.Error("error with read file: ", err)
continue
} else {
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
logger.Debugf("fileContents: %s", fileContents)
}
}
}
@@ -320,16 +361,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
}
if len(imgList) > 0 {
imgList = append(imgList, map[string]interface{}{
imgList = append(imgList, map[string]any{
"type": "text",
"text": input.Prompt,
})
req.Messages = append(reqMgs, map[string]interface{}{
req.Messages = append(reqMgs, map[string]any{
"role": "user",
"content": imgList,
})
} else {
req.Messages = append(reqMgs, map[string]interface{}{
req.Messages = append(reqMgs, map[string]any{
"role": "user",
"content": finalPrompt,
})
@@ -445,7 +486,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if input.ChatModel.KeyId > 0 {
h.DB.Where("id", input.ChatModel.KeyId).Find(apiKey)
h.DB.Where("id", input.ChatModel.KeyId).Where("enabled", true).Find(apiKey)
} else { // use the last unused key
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
@@ -516,6 +557,7 @@ func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens
}
func (h *ChatHandler) saveChatHistory(
c *gin.Context,
req types.ApiRequest,
usage Usage,
message types.Message,
@@ -524,6 +566,34 @@ func (h *ChatHandler) saveChatHistory(
promptCreatedAt time.Time,
replyCreatedAt time.Time) {
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(usage.Content)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
logger.Debugf("moderationResult: %+v", moderationResult)
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: userVo.Id,
Source: types.ModerationSourceChat,
Input: usage.Prompt,
Output: usage.Content,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
pushMessage(c, ChatEventError, "很抱歉内容触发敏感词预警AI 无法回答!!!")
// 更新用户算力
if input.ChatModel.Power > 0 {
h.subUserPower(userVo, input, 0, 0)
}
return
}
}
// 追加聊天记录
// for prompt
var promptTokens, replyTokens, totalTokens int
@@ -586,6 +656,22 @@ func (h *ChatHandler) saveChatHistory(
logger.Error("failed to save reply history message: ", err)
}
// 发送完整聊天记录给前端
var messageVo vo.ChatMessage
err = utils.CopyObject(historyReplyMsg, &messageVo)
if err == nil {
// 解析内容
var content vo.MsgContent
err = utils.JsonDecode(historyReplyMsg.Content, &content)
if err != nil {
content.Text = historyReplyMsg.Content
}
messageVo.Content = content
messageVo.CreatedAt = historyReplyMsg.CreatedAt.Unix()
messageVo.UpdatedAt = historyReplyMsg.UpdatedAt.Unix()
pushMessage(c, ChatEventComplete, messageVo)
}
// 更新用户算力
if input.ChatModel.Power > 0 {
h.subUserPower(userVo, input, promptTokens, replyTokens)
@@ -710,221 +796,3 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
logger.Error("写入音频数据到响应失败:", err)
}
}
// // OPenAI 消息发送实现
// func (h *ChatHandler) sendOpenAiMessage(
// req types.ApiRequest,
// userVo vo.User,
// ctx context.Context,
// session *types.ChatSession,
// role model.ChatRole,
// prompt string,
// c *gin.Context) error {
// promptCreatedAt := time.Now() // 记录提问时间
// start := time.Now()
// var apiKey = model.ApiKey{}
// response, err := h.doRequest(ctx, req, session, &apiKey)
// logger.Info("HTTP请求完成耗时", time.Since(start))
// if err != nil {
// if strings.Contains(err.Error(), "context canceled") {
// return fmt.Errorf("用户取消了请求:%s", prompt)
// } else if strings.Contains(err.Error(), "no available key") {
// return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
// }
// return err
// } else {
// defer response.Body.Close()
// }
// if response.StatusCode != 200 {
// body, _ := io.ReadAll(response.Body)
// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
// }
// contentType := response.Header.Get("Prompt-Type")
// if strings.Contains(contentType, "text/event-stream") {
// replyCreatedAt := time.Now() // 记录回复时间
// // 循环读取 Chunk 消息
// var message = types.Message{Role: "assistant"}
// var contents = make([]string, 0)
// var function model.Function
// var toolCall = false
// var arguments = make([]string, 0)
// var reasoning = false
// pushMessage(c, ChatEventStart, "开始响应")
// scanner := bufio.NewScanner(response.Body)
// for scanner.Scan() {
// line := scanner.Text()
// if !strings.Contains(line, "data:") || len(line) < 30 {
// continue
// }
// var responseBody = types.ApiResponse{}
// err = json.Unmarshal([]byte(line[6:]), &responseBody)
// if err != nil { // 数据解析出错
// return errors.New(line)
// }
// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
// continue
// }
// if responseBody.Choices[0].Delta.Prompt == nil &&
// responseBody.Choices[0].Delta.ToolCalls == nil &&
// responseBody.Choices[0].Delta.ReasoningContent == "" {
// continue
// }
// if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
// pushMessage(c, ChatEventError, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
// break
// }
// var tool types.ToolCall
// if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
// tool = responseBody.Choices[0].Delta.ToolCalls[0]
// if toolCall && tool.Function.Name == "" {
// arguments = append(arguments, tool.Function.Arguments)
// continue
// }
// }
// // 兼容 Function Call
// fun := responseBody.Choices[0].Delta.FunctionCall
// if fun.Name != "" {
// tool = *new(types.ToolCall)
// tool.Function.Name = fun.Name
// } else if toolCall {
// arguments = append(arguments, fun.Arguments)
// continue
// }
// if !utils.IsEmptyValue(tool) {
// res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
// if res.Error == nil {
// toolCall = true
// callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": callMsg,
// })
// contents = append(contents, callMsg)
// }
// continue
// }
// if responseBody.Choices[0].FinishReason == "tool_calls" ||
// responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
// break
// }
// // output stopped
// if responseBody.Choices[0].FinishReason != "" {
// break // 输出完成或者输出中断了
// } else { // 正常输出结果
// // 兼容思考过程
// if responseBody.Choices[0].Delta.ReasoningContent != "" {
// reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
// if !reasoning {
// reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
// reasoning = true
// }
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": reasoningContent,
// })
// contents = append(contents, reasoningContent)
// } else if responseBody.Choices[0].Delta.Prompt != "" {
// finalContent := responseBody.Choices[0].Delta.Prompt
// if reasoning {
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Prompt)
// reasoning = false
// }
// contents = append(contents, utils.InterfaceToString(finalContent))
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": finalContent,
// })
// }
// }
// } // end for
// if err := scanner.Err(); err != nil {
// if strings.Contains(err.Error(), "context canceled") {
// logger.Info("用户取消了请求:", prompt)
// } else {
// logger.Error("信息读取出错:", err)
// }
// }
// if toolCall { // 调用函数完成任务
// params := make(map[string]any)
// _ = utils.JsonDecode(strings.Join(arguments, ""), &params)
// logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
// params["user_id"] = userVo.Id
// var apiRes types.BizVo
// r, err := req2.C().R().SetHeader("Body-Type", "application/json").
// SetHeader("Authorization", function.Token).
// SetBody(params).Post(function.Action)
// errMsg := ""
// if err != nil {
// errMsg = err.Error()
// } else {
// all, _ := io.ReadAll(r.Body)
// err = json.Unmarshal(all, &apiRes)
// if err != nil {
// errMsg = err.Error()
// } else if apiRes.Code != types.Success {
// errMsg = apiRes.Message
// }
// }
// if errMsg != "" {
// errMsg = "调用函数工具出错:" + errMsg
// contents = append(contents, errMsg)
// } else {
// errMsg = utils.InterfaceToString(apiRes.Data)
// contents = append(contents, errMsg)
// }
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": errMsg,
// })
// }
// // 消息发送成功
// if len(contents) > 0 {
// usage := Usage{
// Prompt: prompt,
// Prompt: strings.Join(contents, ""),
// PromptTokens: 0,
// CompletionTokens: 0,
// TotalTokens: 0,
// }
// message.Prompt = usage.Prompt
// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
// }
// } else {
// var respVo OpenAIResVo
// body, err := io.ReadAll(response.Body)
// if err != nil {
// return fmt.Errorf("读取响应失败:%v", body)
// }
// err = json.Unmarshal(body, &respVo)
// if err != nil {
// return fmt.Errorf("解析响应失败:%v", body)
// }
// content := respVo.Choices[0].Message.Prompt
// if strings.HasPrefix(req.Model, "o1-") {
// content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Prompt)
// }
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": content,
// })
// respVo.Usage.Prompt = prompt
// respVo.Usage.Prompt = content
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
// }
// return nil
// }

View File

@@ -42,9 +42,9 @@ func (h *ChatHandler) List(c *gin.Context) {
modelValues = append(modelValues, chat.Model)
}
var roles []model.ChatRole
var roles []model.ChatApp
var models []model.ChatModel
roleMap := make(map[uint]model.ChatRole)
roleMap := make(map[uint]model.ChatApp)
modelMap := make(map[string]model.ChatModel)
h.DB.Where("id IN ?", roleIds).Find(&roles)
h.DB.Where("value IN ?", modelValues).Find(&models)
@@ -205,7 +205,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
}
// 填充角色名称
var role model.ChatRole
var role model.ChatApp
res = h.DB.Where("id", chatItem.RoleId).First(&role)
if res.Error != nil {
resp.ERROR(c, "Role not found")

View File

@@ -26,6 +26,12 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ChatModelHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/model/")
group.GET("list", h.List)
}
// List 模型列表
func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel

View File

@@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage(
TotalTokens: 0,
}
message.Content = usage.Content
h.saveChatHistory(req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
h.saveChatHistory(c, req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
}
} else { // 非流式输出
var respVo OpenAIResVo
@@ -242,7 +242,7 @@ func (h *ChatHandler) sendOpenAiMessage(
pushMessage(c, "text", content)
respVo.Usage.Prompt = input.Prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
h.saveChatHistory(c, req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
}
return nil

View File

@@ -27,6 +27,15 @@ func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
}
// RegisterRoutes 注册路由
func (h *ConfigHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/config/")
// 无需授权的接口
group.GET("get", h.Get)
group.GET("license", h.License)
}
// Get 获取指定的系统配置
func (h *ConfigHandler) Get(c *gin.Context) {
key := c.Query("key")

View File

@@ -10,9 +10,11 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/dalle"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
@@ -25,16 +27,18 @@ import (
type DallJobHandler struct {
BaseHandler
dallService *dalle.Service
uploader *oss.UploaderManager
userService *service.UserService
dallService *dalle.Service
uploader *oss.UploaderManager
userService *service.UserService
moderationManager *moderation.ServiceManager
}
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *DallJobHandler {
return &DallJobHandler{
dallService: service,
uploader: manager,
userService: userService,
dallService: service,
uploader: manager,
userService: userService,
moderationManager: moderationManager,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -42,6 +46,24 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
}
}
// RegisterRoutes 注册路由
func (h *DallJobHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/dall/")
// 公开接口,不需要授权
group.GET("imgWall", h.ImgWall)
group.GET("models", h.GetModels)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}
}
// Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) {
var data types.DallTask
@@ -50,6 +72,29 @@ func (h *DallJobHandler) Image(c *gin.Context) {
return
}
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceDalle,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,提示词未通过文本审核,请重新输入!")
return
}
}
var chatModel model.ChatModel
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
resp.ERROR(c, "模型不存在")
@@ -73,11 +118,12 @@ func (h *DallJobHandler) Image(c *gin.Context) {
UserId: uint(userId),
ModelId: chatModel.Id,
ModelName: chatModel.Value,
Image: data.Image,
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
Power: chatModel.Power,
}
job := model.DallJob{

View File

@@ -13,7 +13,6 @@ import (
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/crawler"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
@@ -31,7 +30,6 @@ import (
type FunctionHandler struct {
BaseHandler
config types.ApiConfig
uploadManager *oss.UploaderManager
dallService *dalle.Service
userService *service.UserService
@@ -49,13 +47,23 @@ func NewFunctionHandler(
App: server,
DB: db,
},
config: config.ApiConfig,
uploadManager: manager,
dallService: dallService,
userService: userService,
}
}
// RegisterRoutes 注册路由
func (h *FunctionHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/function/")
group.GET("list", h.List)
// 需要用户授权的接口
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
}
type resVo struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
@@ -107,16 +115,10 @@ func (h *FunctionHandler) WeiBo(c *gin.Context) {
return
}
if h.config.Token == "" {
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
url := fmt.Sprintf("%s/api/weibo/fetch", types.GeekAPIURL)
var res resVo
r, err := req.C().R().
SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetHeader("Authorization", "Bearer geekai-plus").
SetSuccessResult(&res).Get(url)
if err != nil {
resp.ERROR(c, fmt.Sprintf("%v", err))
@@ -146,16 +148,10 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
return
}
if h.config.Token == "" {
resp.ERROR(c, "无效的 API Token")
return
}
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
url := fmt.Sprintf("%s/api/zaobao/fetch", types.GeekAPIURL)
var res resVo
r, err := req.C().R().
SetHeader("AppId", h.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
SetHeader("Authorization", "Bearer geekai-plus").
SetSuccessResult(&res).Get(url)
if err != nil {
resp.ERROR(c, fmt.Sprintf("%v", err))
@@ -193,16 +189,23 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return
}
var chatModel model.ChatModel
res := h.DB.Where("type = ?", "img").Where("enabled", true).First(&chatModel)
if res.Error != nil {
resp.ERROR(c, "没有找到可用的AI绘图模型")
return
}
logger.Debugf("绘画参数:%+v", params)
var user model.User
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
res = h.DB.Where("id = ?", params["user_id"]).First(&user)
if res.Error != nil {
resp.ERROR(c, "当前用户不存在!")
return
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
if user.Power < chatModel.Power {
resp.ERROR(c, "创建绘图任务失败,算力不足")
return
}
@@ -211,24 +214,24 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
task := types.DallTask{
UserId: user.Id,
Prompt: prompt,
ModelId: 0,
ModelName: "dall-e-3",
TranslateModelId: h.App.SysConfig.AssistantModelId,
ModelId: chatModel.Id,
ModelName: chatModel.Value,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
N: 1,
Quality: "standard",
Size: "1024x1024",
Style: "vivid",
Power: h.App.SysConfig.DallPower,
Power: chatModel.Power,
}
job := model.DallJob{
UserId: user.Id,
Prompt: prompt,
Power: h.App.SysConfig.DallPower,
Power: chatModel.Power,
TaskInfo: utils.JsonEncode(task),
}
err := h.DB.Create(&job).Error
if err != nil {
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error())
resp.ERROR(c, "创建绘图任务失败:"+err.Error())
return
}
@@ -253,76 +256,6 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
resp.SUCCESS(c, content)
}
// 实现一个联网搜索的函数工具,采用爬虫实现
func (h *FunctionHandler) WebSearch(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
resp.ERROR(c, err.Error())
return
}
var params map[string]interface{}
if err := c.ShouldBindJSON(&params); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 从参数中获取搜索关键词
keyword, ok := params["keyword"].(string)
if !ok || keyword == "" {
resp.ERROR(c, "搜索关键词不能为空")
return
}
// 从参数中获取最大页数默认为1页
maxPages := 1
if pages, ok := params["max_pages"].(float64); ok {
maxPages = int(pages)
}
// 获取用户ID
userID, ok := params["user_id"].(float64)
if !ok {
resp.ERROR(c, "用户ID不能为空")
return
}
// 查询用户信息
var user model.User
res := h.DB.Where("id = ?", int(userID)).First(&user)
if res.Error != nil {
resp.ERROR(c, "用户不存在")
return
}
// 检查用户算力是否足够
searchPower := 1 // 每次搜索消耗1点算力
if user.Power < searchPower {
resp.ERROR(c, "算力不足,无法执行网络搜索")
return
}
// 执行网络搜索
searchResults, err := crawler.SearchWeb(keyword, maxPages)
if err != nil {
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
return
}
// 扣减用户算力
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
Type: types.PowerConsume,
Model: "web_search",
Remark: fmt.Sprintf("网络搜索:%s", utils.CutWords(keyword, 10)),
})
if err != nil {
resp.ERROR(c, "扣减算力失败:"+err.Error())
return
}
// 返回搜索结果
resp.SUCCESS(c, searchResults)
}
// List 获取所有的工具函数列表
func (h *FunctionHandler) List(c *gin.Context) {
var items []model.Function

View File

@@ -8,14 +8,18 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"strings"
)
// InviteHandler 用户邀请
@@ -27,6 +31,23 @@ func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *InviteHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/invite/")
// 公开接口,不需要授权
group.GET("hits", h.Hits)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("code", h.Code)
group.GET("list", h.List)
group.GET("stats", h.Stats)
group.GET("rules", h.Rules)
}
}
// Code 获取当前用户邀请码
func (h *InviteHandler) Code(c *gin.Context) {
userId := h.GetLoginUserId(c)
@@ -65,21 +86,34 @@ func (h *InviteHandler) List(c *gin.Context) {
var total int64
session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog
var list = make([]vo.InviteLog, 0)
offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var v vo.InviteLog
err := utils.CopyObject(item, &v)
if err == nil {
v.Id = item.Id
v.CreatedAt = item.CreatedAt.Unix()
list = append(list, v)
} else {
logger.Error(err)
}
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
userIds := make([]uint, 0)
for _, item := range items {
userIds = append(userIds, item.UserId)
}
userMap := make(map[uint]model.User)
var users []model.User
h.DB.Model(&model.User{}).Where("id IN (?)", userIds).Find(&users)
for _, user := range users {
userMap[user.Id] = user
}
var list = make([]vo.InviteLog, 0)
for _, item := range items {
var v vo.InviteLog
err := utils.CopyObject(item, &v)
if err != nil {
continue
}
v.CreatedAt = item.CreatedAt.Unix()
v.Avatar = userMap[item.UserId].Avatar
list = append(list, v)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
}
@@ -90,3 +124,89 @@ func (h *InviteHandler) Hits(c *gin.Context) {
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
resp.SUCCESS(c)
}
// Stats 获取邀请统计
func (h *InviteHandler) Stats(c *gin.Context) {
userId := h.GetLoginUserId(c)
// 获取邀请码
var inviteCode model.InviteCode
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "邀请码不存在")
return
}
// 统计累计邀请数
var totalInvite int64
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ?", userId).Count(&totalInvite)
// 统计今日邀请数
today := time.Now().Format("2006-01-02")
var todayInvite int64
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ? AND DATE(created_at) = ?", userId, today).Count(&todayInvite)
// 获取系统配置中的邀请奖励
var config model.Config
var invitePower int = 200 // 默认值
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
var configMap map[string]any
if utils.JsonDecode(config.Value, &configMap) == nil {
if power, ok := configMap["invite_power"].(float64); ok {
invitePower = int(power)
}
}
}
// 计算获得奖励总数
rewardTotal := int(totalInvite) * invitePower
// 构建邀请链接
inviteLink := fmt.Sprintf("%s/register?invite=%s", h.App.Config.StaticUrl, inviteCode.Code)
stats := vo.InviteStats{
InviteCount: int(totalInvite),
RewardTotal: rewardTotal,
TodayInvite: int(todayInvite),
InviteCode: inviteCode.Code,
InviteLink: inviteLink,
}
resp.SUCCESS(c, stats)
}
// Rules 获取奖励规则
func (h *InviteHandler) Rules(c *gin.Context) {
// 获取系统配置中的邀请奖励
var config model.Config
var invitePower int = 200 // 默认值
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
var configMap map[string]interface{}
if utils.JsonDecode(config.Value, &configMap) == nil {
if power, ok := configMap["invite_power"].(float64); ok {
invitePower = int(power)
}
}
}
rules := []vo.RewardRule{
{
Id: 1,
Title: "好友注册",
Desc: "好友通过邀请链接成功注册",
Icon: "icon-user-fill",
Color: "#1989fa",
Reward: invitePower,
},
{
Id: 2,
Title: "好友首次充值",
Desc: "好友首次充值任意金额",
Icon: "icon-money",
Color: "#07c160",
Reward: invitePower * 2, // 假设首次充值奖励是注册奖励的2倍
},
}
resp.SUCCESS(c, rules)
}

View File

@@ -2,11 +2,12 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/jimeng"
"geekai/service/moderation"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
@@ -19,27 +20,34 @@ import (
// JimengHandler 即梦AI处理器
type JimengHandler struct {
BaseHandler
jimengService *jimeng.Service
userService *service.UserService
jimengService *jimeng.Service
userService *service.UserService
moderationManager *moderation.ServiceManager
}
// NewJimengHandler 创建即梦AI处理器
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler {
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService, moderationManager *moderation.ServiceManager) *JimengHandler {
return &JimengHandler{
BaseHandler: BaseHandler{App: app, DB: db},
jimengService: jimengService,
userService: userService,
BaseHandler: BaseHandler{App: app, DB: db},
jimengService: jimengService,
userService: userService,
moderationManager: moderationManager,
}
}
// RegisterRoutes 注册路由,新增统一任务接口
func (h *JimengHandler) RegisterRoutes() {
rg := h.App.Engine.Group("/api/jimeng")
rg.POST("task", h.CreateTask) // 只保留统一任务接口
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口
rg.POST("jobs", h.Jobs)
rg.GET("remove", h.Remove)
rg.GET("retry", h.Retry)
group := h.App.Engine.Group("/api/jimeng/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("task", h.CreateTask)
group.GET("power-config", h.GetPowerConfig)
group.POST("jobs", h.Jobs)
group.GET("remove", h.Remove)
group.GET("retry", h.Retry)
}
}
// JimengTaskRequest 统一任务请求结构体
@@ -70,6 +78,31 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceJiMeng,
Input: req.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
// 新增:除图像特效外,其他任务类型必须有提示词
if req.TaskType != "image_effects" && req.Prompt == "" {
resp.ERROR(c, "提示词不能为空")
@@ -153,12 +186,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
"seed": req.Seed,
"scale": req.Scale,
}
if len(req.ImageUrls) > 0 {
params["image_urls"] = req.ImageUrls
}
if len(req.BinaryDataBase64) > 0 {
params["binary_data_base64"] = req.BinaryDataBase64
}
params["image_urls"] = []string{req.ImageInput}
case "image_effects":
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
taskType = model.JMTaskTypeImageEffects
@@ -181,9 +209,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
taskType = model.JMTaskTypeTextToVideo
reqKey = jimeng.ReqKeyTextToVideo
modelName = "即梦文生视频"
if req.Seed == 0 {
req.Seed = -1
}
if req.AspectRatio == "" {
req.AspectRatio = jimeng.AspectRatio16_9
}
@@ -196,9 +221,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
taskType = model.JMTaskTypeImageToVideo
reqKey = jimeng.ReqKeyImageToVideo
modelName = "即梦图生视频"
if req.Seed == 0 {
req.Seed = -1
}
params = map[string]any{
"seed": req.Seed,
"aspect_ratio": req.AspectRatio,
@@ -333,8 +355,10 @@ func (h *JimengHandler) Remove(c *gin.Context) {
resp.ERROR(c, "无权限操作")
return
}
if job.Status != model.JMTaskStatusFailed {
resp.ERROR(c, "只有失败的任务能删除")
// 正在运行中的任务能删除
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
return
}
@@ -345,17 +369,20 @@ func (h *JimengHandler) Remove(c *gin.Context) {
return
}
// 退回算力
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "jimeng",
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
})
if err != nil {
resp.ERROR(c, "退回算力失败")
tx.Rollback()
return
// 失败任务删除后退回算力
if job.Status != model.JMTaskStatusFailed {
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "jimeng",
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
})
if err != nil {
resp.ERROR(c, "退回算力失败")
tx.Rollback()
return
}
}
tx.Commit()
resp.SUCCESS(c, gin.H{})
@@ -408,7 +435,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
// getPowerFromConfig 从配置中获取指定类型的算力消耗
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
config := h.jimengService.GetConfig()
config := h.App.SysConfig.Jimeng
switch taskType {
case model.JMTaskTypeTextToImage:
@@ -430,7 +457,7 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
// GetPowerConfig 获取即梦各任务类型算力消耗配置
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
config := h.jimengService.GetConfig()
config := h.App.SysConfig.Jimeng
resp.SUCCESS(c, gin.H{
"text_to_image": config.Power.TextToImage,
"image_to_image": config.Power.ImageToImage,

View File

@@ -10,6 +10,7 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
@@ -35,6 +36,17 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
}
}
// RegisterRoutes 注册路由
func (h *MarkMapHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/markMap/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("gen", h.Generate)
}
}
// Generate 生成思维导图
func (h *MarkMapHandler) Generate(c *gin.Context) {
var data struct {

View File

@@ -13,6 +13,7 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -25,6 +26,12 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *MenuHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/menu/")
group.GET("list", h.List)
}
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
index := h.GetBool(c, "index")
@@ -33,7 +40,7 @@ func (h *MenuHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{})
session = session.Where("enabled", true)
if index {
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
session = session.Where("id IN ?", h.App.SysConfig.Base.IndexNavs)
}
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil {

View File

@@ -10,9 +10,11 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/mj"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
@@ -27,18 +29,20 @@ import (
type MidJourneyHandler struct {
BaseHandler
mjService *mj.Service
snowflake *service.Snowflake
uploader *oss.UploaderManager
userService *service.UserService
mjService *mj.Service
snowflake *service.Snowflake
uploader *oss.UploaderManager
userService *service.UserService
moderationManager *moderation.ServiceManager
}
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *MidJourneyHandler {
return &MidJourneyHandler{
snowflake: snowflake,
mjService: service,
uploader: manager,
userService: userService,
snowflake: snowflake,
mjService: service,
uploader: manager,
userService: userService,
moderationManager: moderationManager,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -46,6 +50,25 @@ func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.S
}
}
// RegisterRoutes 注册路由
func (h *MidJourneyHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/mj/")
// 公开接口,不需要授权
group.GET("imgWall", h.ImgWall)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("image", h.Image)
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}
}
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
@@ -53,7 +76,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false
}
if user.Power < h.App.SysConfig.MjPower {
if user.Power < h.App.SysConfig.Base.MjPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
@@ -90,6 +113,29 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
// 文本审核
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceMJ,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
var params = ""
if data.Rate != "" && !strings.Contains(params, "--ar") {
params += " --ar " + data.Rate
@@ -159,8 +205,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
TranslateModelId: h.App.SysConfig.AssistantModelId,
Mode: h.App.SysConfig.Base.MjMode,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
}
job := model.MidJourneyJob{
Type: data.TaskType,
@@ -169,7 +215,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower,
Power: h.App.SysConfig.Base.MjPower,
CreatedAt: time.Now(),
}
opt := "绘图"
@@ -232,7 +278,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Index: data.Index,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
Mode: h.App.SysConfig.Base.MjMode,
}
job := model.MidJourneyJob{
Type: types.TaskUpscale.String(),
@@ -240,7 +286,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
TaskId: taskId,
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Power: h.App.SysConfig.MjActionPower,
Power: h.App.SysConfig.Base.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
@@ -287,7 +333,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
ChannelId: data.ChannelId,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
Mode: h.App.SysConfig.Base.MjMode,
}
job := model.MidJourneyJob{
Type: types.TaskVariation.String(),
@@ -296,7 +342,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
TaskId: taskId,
TaskInfo: utils.JsonEncode(task),
Progress: 0,
Power: h.App.SysConfig.MjActionPower,
Power: h.App.SysConfig.Base.MjActionPower,
CreatedAt: time.Now(),
}
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {

View File

@@ -9,6 +9,7 @@ package handler
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service/oss"
"geekai/store/model"
@@ -32,6 +33,22 @@ func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManage
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
// RegisterRoutes 注册路由
func (h *NetHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/upload")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("", h.Upload)
group.POST("list", h.List)
group.GET("remove", h.Remove)
}
// 公开接口,不需要授权
h.App.Engine.GET("/api/download", h.Download)
}
func (h *NetHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {

View File

@@ -9,12 +9,12 @@ package handler
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -28,6 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *OrderHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/order/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("list", h.List)
group.GET("query", h.Query)
}
}
// List 订单列表
func (h *OrderHandler) List(c *gin.Context) {
page := h.GetInt(c, "page", 1)
@@ -48,20 +60,21 @@ func (h *OrderHandler) List(c *gin.Context) {
order.Id = item.Id
order.CreatedAt = item.CreatedAt.Unix()
order.UpdatedAt = item.UpdatedAt.Unix()
payMethod, ok := types.PayMethods[item.PayWay]
payChannel, ok := types.PayChannel[item.Channel]
if !ok {
payMethod = item.PayWay
payChannel = item.PayWay
}
payName, ok := types.PayNames[item.PayType]
payWays, ok := types.PayWays[item.PayWay]
if !ok {
payName = item.PayWay
payWays = item.PayWay
}
order.PayMethod = payMethod
order.PayName = payName
order.ChannelName = payChannel
order.PayName = payWays
list = append(list, order)
} else {
logger.Error(err)
}
}
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
@@ -82,17 +95,8 @@ func (h *OrderHandler) Query(c *gin.Context) {
return
}
counter := 0
for {
time.Sleep(time.Second)
var item model.Order
h.DB.Where("order_no = ?", orderNo).First(&item)
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
order.Status = item.Status
break
}
counter++
}
var item model.Order
h.DB.Where("order_no = ?", orderNo).First(&item)
resp.SUCCESS(c, gin.H{"status": order.Status})
}

View File

@@ -11,6 +11,7 @@ import (
"embed"
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/payment"
@@ -33,52 +34,148 @@ type PayWay struct {
// PaymentHandler 支付服务回调 handler
type PaymentHandler struct {
BaseHandler
alipayService *payment.AlipayService
huPiPayService *payment.HuPiPayService
geekPayService *payment.GeekPayService
wechatPayService *payment.WechatPayService
snowflake *service.Snowflake
userService *service.UserService
fs embed.FS
lock sync.Mutex
signKey string // 用来签名的随机秘钥
alipayService *payment.AlipayService
epayService *payment.EPayService
wxpayService *payment.WxPayService
snowflake *service.Snowflake
userService *service.UserService
fs embed.FS
lock sync.Mutex
config *types.PaymentConfig
}
func NewPaymentHandler(
server *core.AppServer,
alipayService *payment.AlipayService,
huPiPayService *payment.HuPiPayService,
geekPayService *payment.GeekPayService,
wechatPayService *payment.WechatPayService,
geekPayService *payment.EPayService,
wxpayService *payment.WxPayService,
db *gorm.DB,
userService *service.UserService,
snowflake *service.Snowflake,
fs embed.FS) *PaymentHandler {
fs embed.FS,
sysConfig *types.SystemConfig) *PaymentHandler {
return &PaymentHandler{
alipayService: alipayService,
huPiPayService: huPiPayService,
geekPayService: geekPayService,
wechatPayService: wechatPayService,
snowflake: snowflake,
userService: userService,
fs: fs,
lock: sync.Mutex{},
alipayService: alipayService,
epayService: geekPayService,
wxpayService: wxpayService,
snowflake: snowflake,
userService: userService,
fs: fs,
lock: sync.Mutex{},
BaseHandler: BaseHandler{
App: server,
DB: db,
},
signKey: utils.RandString(32),
config: &sysConfig.Payment,
}
}
func (h *PaymentHandler) Pay(c *gin.Context) {
// RegisterRoutes 注册路由
func (h *PaymentHandler) RegisterRoutes() {
rg := h.App.Engine.Group("/api/payment/")
// 支付回调接口(公开)
rg.POST("notify/alipay", h.AlipayNotify)
rg.GET("notify/epay", h.EPayNotify)
rg.POST("notify/wxpay", h.WxpayNotify)
// 需要用户登录的接口
rg.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
rg.POST("create", h.CreateOrder)
}
}
func (h *PaymentHandler) StartSyncOrders() {
go func() {
for {
err := h.SyncOrders()
if err != nil {
logger.Error(err)
}
time.Sleep(time.Second * 5)
}
}()
}
// SyncOrders 同步订单状态
func (h *PaymentHandler) SyncOrders() error {
defer func() {
if err := recover(); err != nil {
logger.Errorf("同步订单状态发生异常: %v", err)
}
}()
var orders []model.Order
err := h.DB.Where("status", types.OrderNotPaid).Where("checked", false).Find(&orders).Error
if err != nil {
return err
}
for _, order := range orders {
time.Sleep(time.Second * 1)
//超时15分钟的订单直接标记为已关闭
if time.Now().After(order.CreatedAt.Add(time.Minute * 5)) {
h.DB.Model(&model.Order{}).Where("id", order.Id).Update("checked", true)
logger.Errorf("订单超时:%v", order)
continue
}
// 查询订单状态
var res payment.OrderInfo
switch order.Channel {
case payment.PayChannelEpay:
res, err = h.epayService.Query(order.OrderNo)
if err != nil {
logger.Errorf("error with query order info: %v", err)
continue
}
// 微信支付
case payment.PayChannelWX:
res, err = h.wxpayService.Query(order.OrderNo)
logger.Debugf("微信支付订单状态:%+v", res)
if err != nil {
logger.Errorf("error with query order info: %v", err)
continue
}
case payment.PayChannelAL:
res, err = h.alipayService.Query(order.OrderNo)
if err != nil {
logger.Errorf("error with query order info: %v", err)
continue
}
}
// 订单已关闭
if res.Closed() {
h.DB.Model(&model.Order{}).Where("id", order.Id).Updates(map[string]any{
"checked": true,
"status": types.OrderPaidFailed,
})
logger.Errorf("订单已关闭:%v", order)
continue
}
// 订单未支付,不处理,继续轮询
if !res.Success() {
continue
}
// 订单支付成功
err = h.paySuccess(res)
if err != nil {
logger.Errorf("error with deal order: %v", err)
continue
}
}
return nil
}
func (h *PaymentHandler) CreateOrder(c *gin.Context) {
var data struct {
PayWay string `json:"pay_way"`
PayType string `json:"pay_type"`
ProductId int `json:"product_id"`
UserId int `json:"user_id"`
Device string `json:"device"`
Host string `json:"host"`
PayWay string `json:"pay_way,omitempty"` // 支付方式:支付宝,微信
Pid int `json:"pid,omitempty"`
Device string `json:"device,omitempty"`
Domain string `json:"domain,omitempty"` // 支付回调域名
Channel string `json:"channel,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -86,7 +183,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
}
var product model.Product
err := h.DB.Where("id", data.ProductId).First(&product).Error
err := h.DB.Where("id", data.Pid).First(&product).Error
if err != nil {
resp.ERROR(c, "Product not found")
return
@@ -97,136 +194,118 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
resp.ERROR(c, "error with generate trade no: "+err.Error())
return
}
userId := h.GetLoginUserId(c)
var user model.User
err = h.DB.Where("id", data.UserId).First(&user).Error
err = h.DB.Where("id", userId).First(&user).Error
if err != nil {
resp.NotAuth(c)
return
}
amount := product.Discount
var payURL, returnURL, notifyURL string
amount := product.Price
var payURL, notifyURL string
switch data.PayWay {
case "alipay":
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
notifyURL = h.App.Config.AlipayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
}
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
returnURL = h.App.Config.AlipayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
money := fmt.Sprintf("%.2f", amount)
if data.Device == "wechat" {
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
case "wxpay":
logger.Debugf("微信支付,%+v", data)
data.Channel = payment.PayChannelWX
// 优先使用微信官方支付
if h.config.WxPay.Enabled {
data.Channel = "wxpay"
if h.config.WxPay.Domain != "" {
data.Domain = h.config.WxPay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/wxpay", data.Domain)
payURL, err = h.wxpayService.Pay(payment.PayRequest{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
} else {
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
ReturnURL: returnURL,
NotifyURL: notifyURL,
})
}
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
return
}
break
case "wechat":
if h.App.Config.WechatPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
}
if data.Device == "wechat" {
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
TotalFee: fmt.Sprintf("%d", int(amount*100)),
Subject: product.Name,
NotifyURL: notifyURL,
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: payment.PayWayWX,
})
} else {
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
if err != nil {
resp.ERROR(c, err.Error())
return
}
} else if h.config.Epay.Enabled { // 聚合支付
logger.Debugf("聚合支付%+v", data)
data.Channel = payment.PayChannelEpay
if h.config.Epay.Domain != "" {
data.Domain = h.config.Epay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
params := payment.PayRequest{
OutTradeNo: orderNo,
TotalFee: int(amount * 100),
Subject: product.Name,
TotalFee: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: payment.PayWayWX,
NotifyURL: notifyURL,
}
r, err := h.epayService.Pay(params)
logger.Debugf("请求支付结果,%+v", r)
if err != nil {
resp.ERROR(c, err.Error())
return
} else {
payURL = r
}
} else {
resp.ERROR(c, "系统没有配置可用的支付渠道!")
return
}
case "alipay":
if h.config.Alipay.Enabled {
logger.Debugf("支付宝,%+v", data)
data.Channel = payment.PayChannelAL
if h.config.Alipay.Domain != "" { // 用于本地调试支付
data.Domain = h.config.Alipay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Domain)
money := fmt.Sprintf("%.2f", amount)
payURL, err = h.alipayService.Pay(payment.PayRequest{
Device: data.Device,
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: money,
NotifyURL: notifyURL,
})
}
if err != nil {
resp.ERROR(c, err.Error())
return
}
break
case "hupi":
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
}
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
Version: "1.1",
TradeOrderId: orderNo,
TotalFee: fmt.Sprintf("%f", amount),
Title: product.Name,
NotifyURL: notifyURL,
ReturnURL: returnURL,
WapName: "GeekAI助手",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
payURL = r.URL
break
case "geek":
if h.App.Config.GeekPayConfig.NotifyURL != "" {
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
} else {
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
}
if h.App.Config.GeekPayConfig.ReturnURL != "" {
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
}
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
} else {
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
}
params := payment.GeekPayParams{
OutTradeNo: orderNo,
Method: "web",
Name: product.Name,
Money: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
Type: data.PayType,
ReturnURL: returnURL,
NotifyURL: notifyURL,
}
res, err := h.geekPayService.Pay(params)
if err != nil {
resp.ERROR(c, err.Error())
if err != nil {
resp.ERROR(c, "error with generate pay url: "+err.Error())
return
}
} else if h.config.Epay.Enabled { // 聚合支付
logger.Debugf("聚合支付,%+v", data)
data.Channel = payment.PayChannelEpay
if h.config.Epay.Domain != "" {
data.Domain = h.config.Epay.Domain
}
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
params := payment.PayRequest{
OutTradeNo: orderNo,
Subject: product.Name,
TotalFee: fmt.Sprintf("%f", amount),
ClientIP: c.ClientIP(),
Device: data.Device,
PayWay: data.PayWay,
NotifyURL: notifyURL,
}
r, err := h.epayService.Pay(params)
if err != nil {
resp.ERROR(c, err.Error())
return
} else {
payURL = r
}
} else {
resp.ERROR(c, "系统没有配置可用的支付渠道!")
return
}
payURL = res.PayURL
default:
resp.ERROR(c, "不支持的支付渠道")
return
@@ -234,43 +313,40 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
Power: product.Power,
Name: product.Name,
Price: product.Price,
}
order := model.Order{
UserId: user.Id,
Username: user.Username,
ProductId: product.Id,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: data.PayWay,
PayType: data.PayType,
Remark: utils.JsonEncode(remark),
UserId: user.Id,
Username: user.Username,
OrderNo: orderNo,
Subject: product.Name,
Amount: amount,
Status: types.OrderNotPaid,
PayWay: data.PayWay,
Channel: data.Channel,
Remark: utils.JsonEncode(remark),
}
err = h.DB.Create(&order).Error
if err != nil {
resp.ERROR(c, "error with create order: "+err.Error())
return
}
resp.SUCCESS(c, payURL)
resp.SUCCESS(c, gin.H{"pay_url": payURL, "order_no": orderNo})
}
// 异步通知回调公共逻辑
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// 支付成功处理
func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
h.lock.Lock()
defer h.lock.Unlock()
var order model.Order
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
err := h.DB.Where("order_no", info.OutTradeNo).First(&order).Error
if err != nil {
return fmt.Errorf("error with fetch order: %v", err)
}
h.lock.Lock()
defer h.lock.Unlock()
// 已支付订单,直接返回
if order.Status == types.OrderPaidSuccess {
return nil
@@ -290,19 +366,21 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
// 增加用户算力
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
Type: types.PowerRecharge,
Model: order.PayWay,
Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
Type: types.PowerRecharge,
Model: order.Subject,
Remark: fmt.Sprintf("充值算力,金额:%f订单号%s", order.Amount, order.OrderNo),
CreatedAt: time.Now(),
})
if err != nil {
return err
}
// 更新订单状态
order.PayTime = time.Now().Unix()
order.PayTime = utils.Str2stamp(info.PayTime)
order.Status = types.OrderPaidSuccess
order.TradeNo = tradeNo
err = h.DB.Updates(&order).Error
order.TradeNo = info.TradeId
order.Checked = true
err = h.DB.Debug().Updates(&order).Error
if err != nil {
return fmt.Errorf("error with update order info: %v", err)
}
@@ -317,54 +395,6 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
return nil
}
// GetPayWays 获取支付方式
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
payWays := make([]gin.H, 0)
if h.App.Config.AlipayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
}
if h.App.Config.HuPiPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
}
if h.App.Config.GeekPayConfig.Enabled {
for _, v := range h.App.Config.GeekPayConfig.Methods {
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
}
}
if h.App.Config.WechatPayConfig.Enabled {
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
}
resp.SUCCESS(c, payWays)
}
// HuPiPayNotify 虎皮椒支付异步回调
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
orderNo := c.Request.Form.Get("trade_order_id")
tradeNo := c.Request.Form.Get("open_order_id")
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
if err = h.huPiPayService.Check(orderNo); err != nil {
logger.Error("订单校验失败:", err)
c.String(http.StatusOK, "fail")
return
}
err = h.notify(orderNo, tradeNo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
c.String(http.StatusOK, "success")
}
// AlipayNotify 支付宝支付回调
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
err := c.Request.ParseForm()
@@ -373,16 +403,15 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
return
}
result := h.alipayService.TradeVerify(c.Request)
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
if !result.Success() {
logger.Error("订单校验失败:", result.Message)
orderInfo, err := h.alipayService.Query(c.Request.Form.Get("out_trade_no"))
logger.Infof("收到支付宝商号订单支付回调:%+v", orderInfo)
if !orderInfo.Success() {
logger.Errorf("订单校验失败:%v", err)
c.String(http.StatusOK, "fail")
return
}
tradeNo := c.Request.Form.Get("trade_no")
err = h.notify(result.OutTradeNo, tradeNo)
err = h.paySuccess(orderInfo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
@@ -392,28 +421,35 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
c.String(http.StatusOK, "success")
}
// GeekPayNotify 支付异步回调
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
// EPayNotify 易支付支付异步回调
func (h *PaymentHandler) EPayNotify(c *gin.Context) {
var params = make(map[string]string)
for k := range c.Request.URL.Query() {
params[k] = c.Query(k)
}
logger.Infof("收到GeekPay订单支付回调:%+v", params)
// 检查支付状态
logger.Infof("收到易支付订单支付回调:%+v", params)
// 检查支付状态, 如果未支付,则返回成功
if params["trade_status"] != "TRADE_SUCCESS" {
c.String(http.StatusOK, "success")
return
}
sign := h.geekPayService.Sign(params)
sign := h.epayService.Sign(params)
if sign != c.Query("sign") {
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
c.String(http.StatusOK, "fail")
return
}
// 查询订单状态
order, err := h.epayService.Query(params["out_trade_no"])
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
return
}
err := h.notify(params["out_trade_no"], params["trade_no"])
err = h.paySuccess(order)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")
@@ -423,26 +459,23 @@ func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
c.String(http.StatusOK, "success")
}
// WechatPayNotify 微信商户支付异步回调
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
// WxpayNotify 微信商户支付异步回调
func (h *PaymentHandler) WxpayNotify(c *gin.Context) {
err := c.Request.ParseForm()
if err != nil {
c.String(http.StatusOK, "fail")
return
}
result := h.wechatPayService.TradeVerify(c.Request)
logger.Infof("收到微信商号订单支付回调:%+v", result)
if !result.Success() {
logger.Error("订单校验失败:", err)
c.JSON(http.StatusBadRequest, gin.H{
"code": "FAIL",
"message": err.Error(),
})
orderInfo, err := h.wxpayService.TradeVerify(c.Request)
logger.Infof("收到微信商号订单支付回调:%+v", orderInfo)
if err != nil {
logger.Errorf("订单校验失败:%v", err)
c.JSON(http.StatusBadRequest, gin.H{"code": "FAIL"})
return
}
err = h.notify(result.OutTradeNo, result.TradeId)
err = h.paySuccess(orderInfo)
if err != nil {
logger.Error(err)
c.String(http.StatusOK, "fail")

View File

@@ -9,11 +9,13 @@ package handler
import (
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -27,6 +29,18 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *PowerLogHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/powerLog/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("list", h.List)
group.GET("stats", h.Stats)
}
}
func (h *PowerLogHandler) List(c *gin.Context) {
var data struct {
Model string `json:"model"`
@@ -72,3 +86,45 @@ func (h *PowerLogHandler) List(c *gin.Context) {
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
}
// Stats 获取用户算力统计
func (h *PowerLogHandler) Stats(c *gin.Context) {
userId := h.GetLoginUserId(c)
if userId == 0 {
resp.NotAuth(c)
return
}
// 获取用户信息(包含余额)
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
resp.ERROR(c, "用户不存在")
return
}
// 计算总消费(所有支出记录)
var totalConsume int64
h.DB.Model(&model.PowerLog{}).
Where("user_id", userId).
Where("mark", types.PowerSub).
Select("COALESCE(SUM(amount), 0)").
Scan(&totalConsume)
// 计算今日消费
today := time.Now().Format("2006-01-02")
var todayConsume int64
h.DB.Model(&model.PowerLog{}).
Where("user_id", userId).
Where("mark", types.PowerSub).
Where("DATE(created_at) = ?", today).
Select("COALESCE(SUM(amount), 0)").
Scan(&todayConsume)
stats := map[string]interface{}{
"total": totalConsume,
"today": todayConsume,
"balance": user.Power,
}
resp.SUCCESS(c, stats)
}

View File

@@ -13,6 +13,7 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -25,6 +26,12 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
}
// RegisterRoutes 注册路由
func (h *ProductHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/product/")
group.GET("list", h.List)
}
// List 模型列表
func (h *ProductHandler) List(c *gin.Context) {
var items []model.Product

View File

@@ -10,12 +10,14 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -39,6 +41,20 @@ func NewPromptHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
}
}
// RegisterRoutes 注册路由
func (h *PromptHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/prompt/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis)).Use(middleware.RateLimitEvery(h.App.Redis, 30*time.Second))
{
group.POST("lyric", h.Lyric)
group.POST("image", h.Image)
group.POST("video", h.Video)
group.POST("meta", h.MetaPrompt)
}
}
// Lyric 生成歌词
func (h *PromptHandler) Lyric(c *gin.Context) {
var data struct {
@@ -48,25 +64,12 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成歌词",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, content)
}
@@ -79,23 +82,12 @@ func (h *PromptHandler) Image(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成绘画提示词",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, strings.Trim(content, `"`))
}
@@ -108,25 +100,12 @@ func (h *PromptHandler) Video(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if h.App.SysConfig.PromptPower > 0 {
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
Type: types.PowerConsume,
Model: h.getPromptModel(),
Remark: "生成视频脚本",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
}
resp.SUCCESS(c, strings.Trim(content, `"`))
}
@@ -158,9 +137,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
}
func (h *PromptHandler) getPromptModel() string {
if h.App.SysConfig.AssistantModelId > 0 {
if h.App.SysConfig.Base.AssistantModelId > 0 {
var chatModel model.ChatModel
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
h.DB.Where("id", h.App.SysConfig.Base.AssistantModelId).First(&chatModel)
return chatModel.Value
}
return "gpt-4o"

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
@@ -39,6 +40,18 @@ func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *servic
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
}
// RegisterRoutes 注册路由
func (h *RealtimeHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/realtime/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.Any("", h.Connection)
group.POST("voice", h.VoiceChat)
}
}
func (h *RealtimeHandler) Connection(c *gin.Context) {
// 获取客户端请求中指定的子协议
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
@@ -154,7 +167,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
return
}
if user.Power < h.App.SysConfig.AdvanceVoicePower {
if user.Power < h.App.SysConfig.Base.AdvanceVoicePower {
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
return
}
@@ -198,7 +211,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 扣减算力
err = h.userService.DecreasePower(userId, h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.AdvanceVoicePower, model.PowerLog{
Type: types.PowerConsume,
Model: "advanced-voice",
Remark: "实时语音通话",

View File

@@ -10,14 +10,16 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"sync"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type RedeemHandler struct {
@@ -30,6 +32,17 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
}
// RegisterRoutes 注册路由
func (h *RedeemHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/redeem/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("verify", h.Verify)
}
}
func (h *RedeemHandler) Verify(c *gin.Context) {
var data struct {
Code string `json:"code"`

View File

@@ -10,8 +10,10 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
@@ -28,12 +30,13 @@ import (
type SdJobHandler struct {
BaseHandler
redis *redis.Client
sdService *sd.Service
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
userService *service.UserService
redis *redis.Client
sdService *sd.Service
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
userService *service.UserService
moderationManager *moderation.ServiceManager
}
func NewSdJobHandler(app *core.AppServer,
@@ -42,13 +45,15 @@ func NewSdJobHandler(app *core.AppServer,
manager *oss.UploaderManager,
snowflake *service.Snowflake,
userService *service.UserService,
levelDB *store.LevelDB) *SdJobHandler {
levelDB *store.LevelDB,
moderationManager *moderation.ServiceManager) *SdJobHandler {
return &SdJobHandler{
sdService: service,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
userService: userService,
sdService: service,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
userService: userService,
moderationManager: moderationManager,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -56,6 +61,23 @@ func NewSdJobHandler(app *core.AppServer,
}
}
// RegisterRoutes 注册路由
func (h *SdJobHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/sd/")
// 公开接口,不需要授权
group.GET("imgWall", h.ImgWall)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}
}
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
@@ -63,7 +85,7 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
return false
}
if user.Power < h.App.SysConfig.SdPower {
if user.Power < h.App.SysConfig.Base.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
@@ -84,6 +106,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return
}
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceSD,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
if data.Width <= 0 {
data.Width = 512
}
@@ -131,7 +176,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
HdSteps: data.HdSteps,
},
UserId: userId,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
}
job := model.SdJob{
@@ -142,7 +187,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
TaskInfo: utils.JsonEncode(task),
Prompt: data.Prompt,
Progress: 0,
Power: h.App.SysConfig.SdPower,
Power: h.App.SysConfig.Base.SdPower,
CreatedAt: time.Now(),
}
res := h.DB.Create(&job)

View File

@@ -24,24 +24,31 @@ const CodeStorePrefix = "/verify/codes/"
type SmsHandler struct {
BaseHandler
redis *redis.Client
sms *sms.ServiceManager
smtp *service.SmtpService
captcha *service.CaptchaService
redis *redis.Client
sms *sms.SmsManager
smtp *service.SmtpService
captchaService *service.CaptchaService
}
func NewSmsHandler(
app *core.AppServer,
client *redis.Client,
sms *sms.ServiceManager,
sms *sms.SmsManager,
smtp *service.SmtpService,
captcha *service.CaptchaService) *SmsHandler {
return &SmsHandler{
redis: client,
sms: sms,
captcha: captcha,
smtp: smtp,
BaseHandler: BaseHandler{App: app}}
redis: client,
sms: sms,
captchaService: captcha,
smtp: smtp,
BaseHandler: BaseHandler{App: app}}
}
// RegisterRoutes 注册路由
func (h *SmsHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/sms/")
// 无需授权的接口
group.POST("code", h.SendCode)
}
// SendCode 发送验证码
@@ -56,12 +63,12 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.App.SysConfig.EnabledVerify {
if h.captchaService.GetConfig().Enabled {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
check = h.captchaService.SlideCheck(data)
} else {
check = h.captcha.Check(data)
check = h.captchaService.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
@@ -72,14 +79,14 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
code := utils.RandomNumber(6)
var err error
if strings.Contains(data.Receiver, "@") { // email
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "email") {
resp.ERROR(c, "系统已禁用邮箱注册!")
return
}
// 检查邮箱后缀是否在白名单
if len(h.App.SysConfig.EmailWhiteList) > 0 {
if len(h.App.SysConfig.Base.EmailWhiteList) > 0 {
inWhiteList := false
for _, suffix := range h.App.SysConfig.EmailWhiteList {
for _, suffix := range h.App.SysConfig.Base.EmailWhiteList {
if strings.HasSuffix(data.Receiver, suffix) {
inWhiteList = true
break
@@ -92,7 +99,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
}
err = h.smtp.SendVerifyCode(data.Receiver, code)
} else {
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "mobile") {
resp.ERROR(c, "系统已禁用手机号注册!")
return
}

View File

@@ -10,8 +10,10 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/service/suno"
"geekai/store/model"
@@ -26,20 +28,41 @@ import (
type SunoHandler struct {
BaseHandler
sunoService *suno.Service
uploader *oss.UploaderManager
userService *service.UserService
sunoService *suno.Service
uploader *oss.UploaderManager
userService *service.UserService
moderationManager *moderation.ServiceManager
}
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *SunoHandler {
return &SunoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
sunoService: service,
uploader: uploader,
userService: userService,
sunoService: service,
uploader: uploader,
userService: userService,
moderationManager: moderationManager,
}
}
// RegisterRoutes 注册路由
func (h *SunoHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/suno/")
// 公开接口,不需要授权
group.GET("play", h.Play)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("create", h.Create)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.POST("update", h.Update)
group.GET("detail", h.Detail)
}
}
@@ -64,13 +87,36 @@ func (h *SunoHandler) Create(c *gin.Context) {
return
}
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceSuno,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.SunoPower {
if user.Power < h.App.SysConfig.Base.SunoPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
@@ -118,7 +164,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
RefSongId: data.RefSongId,
RefTaskId: data.RefTaskId,
ExtendSecs: data.ExtendSecs,
Power: h.App.SysConfig.SunoPower,
Power: h.App.SysConfig.Base.SunoPower,
SongId: utils.RandString(32),
}
if data.Lyrics != "" {

View File

@@ -1,21 +1,36 @@
package handler
import (
"geekai/core"
"geekai/core/middleware"
"geekai/service"
"geekai/service/payment"
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"net/http"
)
type TestHandler struct {
App *core.AppServer
db *gorm.DB
snowflake *service.Snowflake
js *payment.GeekPayService
js *payment.EPayService
}
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
return &TestHandler{db: db, snowflake: snowflake, js: js}
func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.EPayService) *TestHandler {
return &TestHandler{App: app, db: db, snowflake: snowflake, js: js}
}
// RegisterRoutes 注册路由
func (h *TestHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/test/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.Any("sse", h.PostTest, h.SseTest)
}
}
func (h *TestHandler) SseTest(c *gin.Context) {

View File

@@ -8,8 +8,10 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/store"
@@ -20,8 +22,6 @@ import (
"strings"
"time"
"github.com/imroc/req/v3"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
@@ -36,8 +36,10 @@ type UserHandler struct {
redis *redis.Client
levelDB *store.LevelDB
licenseService *service.LicenseService
captcha *service.CaptchaService
captchaService *service.CaptchaService
userService *service.UserService
wxLoginService *service.WxLoginService
ipSearcher *xdb.Searcher
}
func NewUserHandler(
@@ -48,15 +50,45 @@ func NewUserHandler(
levelDB *store.LevelDB,
captcha *service.CaptchaService,
userService *service.UserService,
wxLoginService *service.WxLoginService,
ipSearcher *xdb.Searcher,
licenseService *service.LicenseService) *UserHandler {
return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher,
redis: client,
levelDB: levelDB,
captcha: captcha,
captchaService: captcha,
licenseService: licenseService,
userService: userService,
wxLoginService: wxLoginService,
ipSearcher: ipSearcher,
}
}
// RegisterRoutes 注册路由
func (h *UserHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/user/")
// 公开接口,不需要授权
group.POST("register", h.Register)
group.POST("login", h.Login)
group.POST("resetPass", h.ResetPass)
group.GET("login/qrcode", h.GetWxLoginQRCode)
group.POST("login/callback", h.WxLoginCallback)
group.GET("login/status", h.GetWxLoginState)
group.GET("logout", h.Logout)
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.GET("session", h.Session)
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.GET("signin", h.SignIn)
}
}
@@ -80,12 +112,13 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
// 人机验证
if h.captchaService.GetConfig().Enabled {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
check = h.captchaService.SlideCheck(data)
} else {
check = h.captcha.Check(data)
check = h.captchaService.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
@@ -125,30 +158,8 @@ func (h *UserHandler) Register(c *gin.Context) {
}
}
// 验证邀请码
inviteCode := model.InviteCode{}
if data.InviteCode != "" {
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
if res.Error != nil {
resp.ERROR(c, "无效的邀请码")
return
}
}
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
ChatConfig: "{}",
ChatModels: "{}",
Power: h.App.SysConfig.InitPower,
}
// check if the username is existing
user := model.User{Username: data.Username, Password: data.Password}
var item model.User
session := h.DB.Session(&gorm.Session{})
if data.Mobile != "" {
@@ -168,78 +179,19 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
// 被邀请人也获得赠送算力
if data.InviteCode != "" {
user.Power += h.App.SysConfig.InvitePower
}
if h.licenseService.GetLicense().Configs.DeCopy {
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
} else {
defaultNickname := h.App.SysConfig.DefaultNickname
if defaultNickname == "" {
defaultNickname = "极客学长"
}
user.Nickname = fmt.Sprintf("%s@%d", defaultNickname, utils.RandomNumber(6))
}
tx := h.DB.Begin()
if err := tx.Create(&user).Error; err != nil {
user, err := h.createNewUser(user, data.InviteCode)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 记录邀请关系
if data.InviteCode != "" {
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.InvitePower, model.PowerLog{
Type: types.PowerInvite,
Model: "Invite",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
// 添加邀请记录
err := tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
}).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
token, err := h.doLogin(&user, c.ClientIP())
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
resp.ERROR(c, err.Error())
return
}
// 保存到 redis
key = fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
}
// Login 用户登录
@@ -255,15 +207,12 @@ func (h *UserHandler) Login(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
needVerify, err := h.redis.Get(c, verifyKey).Bool()
if h.App.SysConfig.EnabledVerify && needVerify {
if h.captchaService.GetConfig().Enabled {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
check = h.captchaService.SlideCheck(data)
} else {
check = h.captcha.Check(data)
check = h.captchaService.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
@@ -274,54 +223,28 @@ func (h *UserHandler) Login(c *gin.Context) {
var user model.User
res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名不存在")
return
}
password := utils.GenPassword(data.Password, user.Salt)
if password != user.Password {
h.redis.Set(c, verifyKey, true, 0)
resp.ERROR(c, "用户名或密码错误")
return
}
if user.Status == false {
if !user.Status {
resp.ERROR(c, "该用户已被禁止登录,请联系管理员")
return
}
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.DB.Model(&user).Updates(user)
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
token, err := h.doLogin(&user, c.ClientIP())
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
resp.ERROR(c, err.Error())
return
}
// 保存到 redis
sessionKey := fmt.Sprintf("users/%d", user.Id)
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
return
}
// 移除登录行为验证码
h.redis.Del(c, verifyKey)
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
}
// Logout 注 销
@@ -333,134 +256,165 @@ func (h *UserHandler) Logout(c *gin.Context) {
resp.SUCCESS(c)
}
// CLogin 第三方登录请求二维码
func (h *UserHandler) CLogin(c *gin.Context) {
returnURL := h.GetTrim(c, "return_url")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
// GetWxLoginQRCode 获取微信登录二维码URL
func (h *UserHandler) GetWxLoginQRCode(c *gin.Context) {
if !h.wxLoginService.GetConfig().Enabled {
resp.ERROR(c, "微信登录功能未启用")
return
}
if h.wxLoginService.GetConfig().ApiKey == "" {
resp.ERROR(c, "微信登录服务令牌未配置")
return
}
state := utils.RandString(32)
qrCodeURL, err := h.wxLoginService.GetLoginQrCodeUrl(state)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
return
}
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
resp.SUCCESS(c, res.Data)
resp.SUCCESS(c, gin.H{
"url": qrCodeURL,
"state": state,
})
}
// CLoginCallback 第三方登录回调
func (h *UserHandler) CLoginCallback(c *gin.Context) {
loginType := c.Query("login_type")
code := c.Query("code")
userId := h.GetInt(c, "user_id", 0)
action := c.Query("action")
// 查询微信登录状态
func (h *UserHandler) GetWxLoginState(c *gin.Context) {
state := c.Query("state")
if state == "" {
resp.ERROR(c, "参数错误")
return
}
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
SetSuccessResult(&res).
Post(apiURL)
status, err := h.wxLoginService.GetLoginStatus(state)
if err != nil {
resp.ERROR(c, err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "error with login http status: "+r.Status)
if status.Status != service.LoginStatusSuccess {
resp.SUCCESS(c, status)
return
}
if res.Code != types.Success {
resp.ERROR(c, "error with http response: "+res.Message)
return
}
// login successfully
data := res.Data.(map[string]interface{})
// 登录成功
var user model.User
if action == "bind" && userId > 0 {
err = h.DB.Where("openid", data["openid"]).First(&user).Error
if err == nil {
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
return
}
err = h.DB.Where("id", userId).First(&user).Error
h.DB.Where("openid = ?", status.OpenID).First(&user)
if user.Id == 0 {
// 创建新用户
user, err = h.createNewUser(model.User{OpenId: status.OpenID}, "")
if err != nil {
resp.ERROR(c, "绑定用户不存在")
resp.ERROR(c, err.Error())
return
}
}
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
if err != nil {
resp.ERROR(c, "更新用户信息失败,"+err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": ""})
token, err := h.doLogin(&user, c.ClientIP())
if err != nil {
resp.ERROR(c, err.Error())
return
}
session := gin.H{}
tx := h.DB.Where("openid", data["openid"]).First(&user)
if tx.Error != nil {
// create new user
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
return
}
status.Status = service.LoginStatusExpired
h.wxLoginService.SetLoginStatus(state, *status)
salt := utils.RandString(8)
password := fmt.Sprintf("%d", utils.RandomNumber(8))
user = model.User{
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
Password: utils.GenPassword(password, salt),
Avatar: fmt.Sprintf("%s", data["avatar"]),
Salt: salt,
Status: true,
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
Power: h.App.SysConfig.InitPower,
OpenId: fmt.Sprintf("%s", data["openid"]),
Nickname: fmt.Sprintf("%s", data["nickname"]),
}
status.Status = service.LoginStatusSuccess
status.Token = token
resp.SUCCESS(c, status)
}
tx = h.DB.Create(&user)
if tx.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(tx.Error)
return
// createNewUser 创建新用户
func (h *UserHandler) createNewUser(user model.User, inviteCode string) (model.User, error) {
if user.OpenId != "" {
user.Platform = "wechat"
user.Nickname = fmt.Sprintf("微信用户@%d", utils.RandomNumber(6))
user.Username = fmt.Sprintf("wx@%d", utils.RandomNumber(8))
user.Password = "geekai123"
} else {
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
if user.Username == "" || user.Password == "" {
return user, fmt.Errorf("用户名或密码不能为空")
}
session["username"] = user.Username
session["password"] = password
} else { // login directly
// 更新最后登录时间和IP
user.LastLoginIp = c.ClientIP()
user.LastLoginAt = time.Now().Unix()
h.DB.Model(&user).Updates(user)
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: c.ClientIP(),
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
})
}
salt := utils.RandString(8)
user.Salt = salt
user.Password = utils.GenPassword(user.Password, salt)
user.Avatar = "/images/avatar/user.png"
user.Status = true
user.ChatRoles = utils.JsonEncode([]string{"gpt"})
user.ChatConfig = "{}"
user.ChatModels = "{}"
user.Power = h.App.SysConfig.Base.InitPower
// 创建用户
tx := h.DB.Begin()
if err := tx.Create(&user).Error; err != nil {
return user, err
}
// 记录邀请关系
if inviteCode != "" {
inviteCode := model.InviteCode{}
err := h.DB.Where("code = ?", inviteCode).First(&inviteCode).Error
if err != nil {
return user, fmt.Errorf("无效的邀请码")
}
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", inviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.Base.InvitePower > 0 {
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.Base.InvitePower, model.PowerLog{
Type: types.PowerInvite,
Model: "Invite",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.Base.InvitePower, inviteCode.Code, user.Username),
})
if err != nil {
tx.Rollback()
return user, err
}
// 添加邀请记录
err = tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.Base.InvitePower),
}).Error
if err != nil {
tx.Rollback()
return user, err
}
}
}
tx.Commit()
return user, nil
}
// doLogin 执行登录操作
func (h *UserHandler) doLogin(user *model.User, ip string) (string, error) {
// 更新最后登录时间和IP
user.LastLoginIp = ip
user.LastLoginAt = time.Now().Unix()
err := h.DB.Model(user).Updates(user).Error
if err != nil {
return "", fmt.Errorf("failed to update user: %v", err)
}
// 记录登录日志
h.DB.Create(&model.UserLoginLog{
UserId: user.Id,
Username: user.Username,
LoginIp: ip,
LoginAddress: utils.Ip2Region(h.ipSearcher, ip),
})
// 创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
@@ -468,17 +422,42 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
})
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
if err != nil {
resp.ERROR(c, "Failed to generate token, "+err.Error())
return
return "", fmt.Errorf("failed to generate token: %v", err)
}
// 保存到 redis
key := fmt.Sprintf("users/%d", user.Id)
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
resp.ERROR(c, "error with save token: "+err.Error())
sessionKey := fmt.Sprintf("users/%d", user.Id)
if _, err = h.redis.Set(context.Background(), sessionKey, tokenString, 0).Result(); err != nil {
return "", fmt.Errorf("error with save token: %v", err)
}
return tokenString, nil
}
// WxLoginCallback 微信登录回调处理
func (h *UserHandler) WxLoginCallback(c *gin.Context) {
var data struct {
OpenID string `json:"openid"`
State string `json:"state"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session["token"] = tokenString
resp.SUCCESS(c, session)
if data.OpenID == "" || data.State == "" {
resp.ERROR(c, "参数错误")
return
}
// 设置登录状态
status := service.LoginStatus{
Status: service.LoginStatusSuccess,
OpenID: data.OpenID,
}
h.wxLoginService.SetLoginStatus(data.State, status)
resp.SUCCESS(c, status)
}
// Session 获取/验证会话
@@ -742,11 +721,11 @@ func (h *UserHandler) SignIn(c *gin.Context) {
// 签到
h.levelDB.Put(key, true)
if h.App.SysConfig.DailyPower > 0 {
h.userService.IncreasePower(userId, h.App.SysConfig.DailyPower, model.PowerLog{
if h.App.SysConfig.Base.DailyPower > 0 {
h.userService.IncreasePower(userId, h.App.SysConfig.Base.DailyPower, model.PowerLog{
Type: types.PowerSignIn,
Model: "SignIn",
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.DailyPower),
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.Base.DailyPower),
})
}
resp.SUCCESS(c)

View File

@@ -10,8 +10,10 @@ package handler
import (
"fmt"
"geekai/core"
"geekai/core/middleware"
"geekai/core/types"
"geekai/service"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/service/video"
"geekai/store/model"
@@ -26,20 +28,37 @@ import (
type VideoHandler struct {
BaseHandler
videoService *video.Service
uploader *oss.UploaderManager
userService *service.UserService
videoService *video.Service
uploader *oss.UploaderManager
userService *service.UserService
moderationManager *moderation.ServiceManager
}
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *VideoHandler {
return &VideoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
videoService: service,
uploader: uploader,
userService: userService,
videoService: service,
uploader: uploader,
userService: userService,
moderationManager: moderationManager,
}
}
// RegisterRoutes 注册路由
func (h *VideoHandler) RegisterRoutes() {
group := h.App.Engine.Group("/api/video/")
// 需要用户授权的接口
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
{
group.POST("luma/create", h.LumaCreate)
group.POST("keling/create", h.KeLingCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}
}
@@ -62,13 +81,36 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
return
}
if h.App.SysConfig.Moderation.Enable {
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
if err != nil {
logger.Error("failed to moderate content: ", err)
}
if moderationResult.Flagged {
// 记录违规内容
moderation := model.Moderation{
UserId: h.GetLoginUserId(c),
Source: types.ModerationSourceVideo,
Input: data.Prompt,
Result: utils.JsonEncode(moderationResult),
}
err = h.DB.Create(&moderation).Error
if err != nil {
logger.Error("failed to save moderation: ", err)
}
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
return
}
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < h.App.SysConfig.LumaPower {
if user.Power < h.App.SysConfig.Base.LumaPower {
resp.ERROR(c, "您的算力不足,请充值后再试!")
return
}
@@ -85,14 +127,14 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
}
// 插入数据库
job := model.VideoJob{
UserId: uint(userId),
Type: types.VideoLuma,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
Power: h.App.SysConfig.Base.LumaPower,
TaskInfo: utils.JsonEncode(task),
}
tx := h.DB.Create(&job)
@@ -147,7 +189,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
// 计算当前任务所需算力
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
power := h.App.SysConfig.KeLingPowers[key]
power := h.App.SysConfig.Base.KeLingPowers[key]
if power == 0 {
resp.ERROR(c, "当前模型暂不支持")
return
@@ -181,7 +223,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
Type: types.VideoKeLing,
Prompt: data.Prompt,
Params: params,
TranslateModelId: h.App.SysConfig.AssistantModelId,
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
Channel: data.Channel,
}
// 插入数据库

View File

@@ -19,6 +19,7 @@ import (
"geekai/service/dalle"
"geekai/service/jimeng"
"geekai/service/mj"
"geekai/service/moderation"
"geekai/service/oss"
"geekai/service/payment"
"geekai/service/sd"
@@ -30,7 +31,7 @@ import (
"log"
"os"
"os/signal"
"strconv"
"runtime/debug"
"syscall"
"time"
@@ -71,15 +72,16 @@ func main() {
if configFile == "" {
configFile = "config.toml"
}
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
logger.Info("Loading config file: ", configFile)
if !debug {
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
defer func() {
if err := recover(); err != nil {
logger.Error("Panic Error:", err)
// 打印堆栈信息
if os.Getenv("GEEKAI_DEBUG") == "true" {
debug.PrintStack()
}
}()
}
}
}()
app := fx.New(
// 初始化配置应用配置
@@ -89,16 +91,16 @@ func main() {
log.Fatal(err)
}
config.Path = configFile
if debug {
_ = core.SaveConfig(config)
}
return config
}),
// 创建应用服务
fx.Provide(core.NewServer),
// 初始化
fx.Invoke(func(s *core.AppServer, client *redis.Client) {
s.Init(debug, client)
s.Init(client)
}),
fx.Provide(func(db *gorm.DB) *types.SystemConfig {
return core.LoadSystemConfig(db)
}),
// 初始化数据库
@@ -126,7 +128,7 @@ func main() {
}),
// 创建控制器
fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewChatAppHandler),
fx.Provide(handler.NewUserHandler),
fx.Provide(handler.NewChatHandler),
fx.Provide(handler.NewNetHandler),
@@ -143,6 +145,12 @@ func main() {
fx.Provide(handler.NewPowerLogHandler),
fx.Provide(handler.NewJimengHandler),
fx.Provide(service.NewMigrationService),
fx.Invoke(func(migrationService *service.MigrationService) {
migrationService.StartMigrate()
}),
// 管理后台控制器
fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler),
fx.Provide(admin.NewApiKeyHandler),
@@ -153,34 +161,23 @@ func main() {
fx.Provide(admin.NewChatModelHandler),
fx.Provide(admin.NewProductHandler),
fx.Provide(admin.NewOrderHandler),
fx.Provide(admin.NewChatHandler),
fx.Provide(admin.NewPowerLogHandler),
fx.Provide(admin.NewAdminJimengHandler),
// 创建服务
fx.Provide(sms.NewSendServiceManager),
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
return service.NewCaptchaService(config.ApiConfig)
}),
fx.Provide(oss.NewUploaderManager),
fx.Provide(dalle.NewService),
fx.Invoke(func(s *dalle.Service) {
s.Run()
s.DownloadImages()
s.CheckTaskStatus()
}),
fx.Provide(service.NewMigrationService),
fx.Invoke(func(s *service.MigrationService) {
s.Migrate()
}),
// 邮件服务
fx.Provide(service.NewSmtpService),
// License 服务
fx.Provide(service.NewLicenseService),
fx.Invoke(func(licenseService *service.LicenseService) {
// licenseService.SyncLicense()
licenseService.SyncLicense()
}),
// Dalle 服务
fx.Provide(dalle.NewService),
fx.Invoke(func(s *dalle.Service) {
s.Run()
s.DownloadImages()
s.CheckTaskStatus()
}),
// MidJourney service pool
@@ -213,302 +210,179 @@ func main() {
}),
// 即梦AI 服务
fx.Provide(jimeng.NewClient),
fx.Provide(jimeng.NewService),
fx.Invoke(func(service *jimeng.Service) {
service.Start()
}),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewJPayService),
fx.Provide(payment.NewWechatService),
fx.Provide(service.NewSnowflake),
fx.Provide(service.NewXXLJobExecutor),
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
if config.XXLConfig.Enabled {
go func() {
log.Fatal(exec.Run())
}()
}
// 创建短信服务
fx.Provide(sms.NewAliYunSmsService),
fx.Provide(sms.NewBaoSmsService),
fx.Provide(sms.NewSmsManager),
fx.Provide(func(config *types.SystemConfig) *service.CaptchaService {
return service.NewCaptchaService(config.Captcha)
}),
fx.Provide(func(config *types.SystemConfig, client *redis.Client) *service.WxLoginService {
return service.NewWxLoginService(config.WxLogin, client)
}),
// 支付服务
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewEPayService),
fx.Provide(payment.NewWxpayService),
// 文件上传服务
fx.Provide(oss.NewLocalStorage),
fx.Provide(oss.NewMiniOss),
fx.Provide(oss.NewQiNiuOss),
fx.Provide(oss.NewAliYunOss),
fx.Provide(oss.NewUploaderManager),
// 用户服务
fx.Provide(service.NewUserService),
// 文本审查服务
fx.Provide(moderation.NewGiteeAIModeration),
fx.Provide(moderation.NewBaiduAIModeration),
fx.Provide(moderation.NewTencentAIModeration),
fx.Provide(moderation.NewServiceManager),
fx.Provide(admin.NewModerationHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ModerationHandler) {
h.RegisterRoutes()
}),
// 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
group := s.Engine.Group("/api/app/")
group.GET("list", h.List)
group.GET("list/user", h.ListByUser)
group.POST("update", h.UpdateRole)
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppHandler) {
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
group := s.Engine.Group("/api/user/")
group.POST("register", h.Register)
group.POST("login", h.Login)
group.GET("logout", h.Logout)
group.GET("session", h.Session)
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.POST("resetPass", h.ResetPass)
group.GET("clogin", h.CLogin)
group.GET("clogin/callback", h.CLoginCallback)
group.GET("signin", h.SignIn)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
group := s.Engine.Group("/api/chat/")
group.Any("message", h.Chat)
group.GET("list", h.List)
group.GET("detail", h.Detail)
group.POST("update", h.Update)
group.GET("remove", h.Remove)
group.GET("history", h.History)
group.GET("clear", h.Clear)
group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate)
group.POST("tts", h.TextToSpeech)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
s.Engine.POST("/api/upload", h.Upload)
s.Engine.POST("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove)
s.Engine.GET("/api/download", h.Download)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/")
group.POST("code", h.SendCode)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
group := s.Engine.Group("/api/captcha/")
group.GET("get", h.Get)
group.POST("check", h.Check)
group.GET("slide/get", h.SlideGet)
group.POST("slide/check", h.SlideCheck)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
group := s.Engine.Group("/api/redeem/")
group.POST("verify", h.Verify)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
group := s.Engine.Group("/api/mj/")
group.POST("image", h.Image)
group.POST("upscale", h.Upscale)
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/")
group.GET("get", h.Get)
group.GET("license", h.License)
h.RegisterRoutes()
}),
// 管理后台控制器
// 管理后台路由注册
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
group := s.Engine.Group("/api/admin/config")
group.POST("update", h.Update)
group.GET("get", h.Get)
group.POST("active", h.Active)
group.GET("fixData", h.FixData)
group.GET("license", h.GetLicense)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/")
group.POST("login", h.Login)
group.GET("logout", h.Logout)
group.GET("session", h.Session)
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("enable", h.Enable)
group.GET("remove", h.Remove)
group.POST("resetPass", h.ResetPass)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
group := s.Engine.Group("/api/admin/apikey/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
group := s.Engine.Group("/api/admin/user/")
group.GET("list", h.List)
group.POST("save", h.Save)
group.GET("remove", h.Remove)
group.GET("loginLog", h.LoginLog)
group.GET("genLoginLink", h.GenLoginLink)
group.POST("resetPass", h.ResetPass)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
group := s.Engine.Group("/api/admin/role/")
group.GET("list", h.List)
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
group := s.Engine.Group("/api/admin/redeem/")
group.GET("list", h.List)
group.POST("create", h.Create)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("export", h.Export)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/")
group.GET("stats", h.Stats)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) {
group := s.Engine.Group("/api/model/")
group.GET("list", h.List)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) {
group := s.Engine.Group("/api/admin/model/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("set", h.Set)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
group := s.Engine.Group("/api/payment/")
group.POST("doPay", h.Pay)
group.GET("payWays", h.GetPayWays)
group.POST("notify/alipay", h.AlipayNotify)
group.GET("notify/geek", h.GeekPayNotify)
group.POST("notify/wechat", h.WechatPayNotify)
group.POST("notify/hupi", h.HuPiPayNotify)
h.RegisterRoutes()
h.StartSyncOrders()
}),
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
group := s.Engine.Group("/api/admin/product/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) {
group := s.Engine.Group("/api/admin/order/")
group.POST("list", h.List)
group.GET("remove", h.Remove)
group.GET("clear", h.Clear)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
group := s.Engine.Group("/api/order/")
group.GET("list", h.List)
group.GET("query", h.Query)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
group := s.Engine.Group("/api/product/")
group.GET("list", h.List)
h.RegisterRoutes()
}),
fx.Provide(handler.NewInviteHandler),
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
group := s.Engine.Group("/api/invite/")
group.GET("code", h.Code)
group.GET("list", h.List)
group.GET("hits", h.Hits)
h.RegisterRoutes()
}),
fx.Provide(admin.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
group := s.Engine.Group("/api/admin/function/")
group.POST("save", h.Save)
group.POST("set", h.Set)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("token", h.GenToken)
h.RegisterRoutes()
}),
fx.Provide(admin.NewUploadHandler),
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
s.Engine.POST("/api/admin/upload", h.Upload)
h.RegisterRoutes()
}),
fx.Provide(handler.NewFunctionHandler),
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
group := s.Engine.Group("/api/function/")
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
group.POST("websearch", h.WebSearch)
group.GET("list", h.List)
h.RegisterRoutes()
}),
fx.Provide(admin.NewChatHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/")
group.POST("list", h.List)
group.POST("message", h.Messages)
group.GET("history", h.History)
group.GET("remove", h.RemoveChat)
group.GET("message/remove", h.RemoveMessage)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
group := s.Engine.Group("/api/powerLog/")
group.POST("list", h.List)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
group := s.Engine.Group("/api/admin/powerLog/")
group.POST("list", h.List)
h.RegisterRoutes()
}),
fx.Provide(admin.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
group := s.Engine.Group("/api/admin/menu/")
group.POST("save", h.Save)
group.GET("list", h.List)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
group.GET("remove", h.Remove)
h.RegisterRoutes()
}),
fx.Provide(handler.NewMenuHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
group := s.Engine.Group("/api/menu/")
group.GET("list", h.List)
h.RegisterRoutes()
}),
fx.Provide(handler.NewMarkMapHandler),
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
s.Engine.POST("/api/markMap/gen", h.Generate)
h.RegisterRoutes()
}),
fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall")
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.GET("models", h.GetModels)
h.RegisterRoutes()
}),
fx.Provide(handler.NewSunoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
group := s.Engine.Group("/api/suno")
group.POST("create", h.Create)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
group.POST("update", h.Update)
group.GET("detail", h.Detail)
group.GET("play", h.Play)
h.RegisterRoutes()
}),
fx.Provide(handler.NewVideoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video")
group.POST("luma/create", h.LumaCreate)
group.POST("keling/create", h.KeLingCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
h.RegisterRoutes()
}),
// 即梦AI 路由
@@ -520,30 +394,19 @@ func main() {
}),
fx.Provide(admin.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
group := s.Engine.Group("/api/admin/app/type")
group.POST("save", h.Save)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.POST("enable", h.Enable)
group.POST("sort", h.Sort)
h.RegisterRoutes()
}),
fx.Provide(handler.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
group := s.Engine.Group("/api/app/type")
group.GET("list", h.List)
h.RegisterRoutes()
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")
group.Any("sse", h.PostTest, h.SseTest)
h.RegisterRoutes()
}),
fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt")
group.POST("/lyric", h.Lyric)
group.POST("/image", h.Image)
group.POST("/video", h.Video)
group.POST("/meta", h.MetaPrompt)
h.RegisterRoutes()
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() {
@@ -568,23 +431,15 @@ func main() {
}),
fx.Provide(admin.NewImageHandler),
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
group := s.Engine.Group("/api/admin/image")
group.POST("/list/mj", h.MjList)
group.POST("/list/sd", h.SdList)
group.POST("/list/dall", h.DallList)
group.GET("/remove", h.Remove)
h.RegisterRoutes()
}),
fx.Provide(admin.NewMediaHandler),
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
group := s.Engine.Group("/api/admin/media")
group.POST("/suno", h.SunoList)
group.POST("/videos", h.Videos)
group.GET("/remove", h.Remove)
h.RegisterRoutes()
}),
fx.Provide(handler.NewRealtimeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
s.Engine.Any("/api/realtime", h.Connection)
s.Engine.POST("/api/realtime/voice", h.VoiceChat)
h.RegisterRoutes()
}),
)
// 启动应用程序

View File

@@ -8,35 +8,38 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core/types"
"github.com/imroc/req/v3"
"time"
"github.com/imroc/req/v3"
)
type CaptchaService struct {
config types.ApiConfig
config types.CaptchaConfig
client *req.Client
}
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
return &CaptchaService{
config: config,
config: captchaConfig,
client: req.C().SetTimeout(10 * time.Second),
}
}
func (s *CaptchaService) UpdateConfig(config types.CaptchaConfig) {
s.config = config
}
func (s *CaptchaService) GetConfig() types.CaptchaConfig {
return s.config
}
func (s *CaptchaService) Get() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL)
url := fmt.Sprintf("%s/api/captcha/get", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
@@ -49,12 +52,11 @@ func (s *CaptchaService) Get() (interface{}, error) {
return res.Data, nil
}
func (s *CaptchaService) Check(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL)
func (s *CaptchaService) Check(data any) bool {
url := fmt.Sprintf("%s/api/captcha/check", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
@@ -68,16 +70,11 @@ func (s *CaptchaService) Check(data interface{}) bool {
return true
}
func (s *CaptchaService) SlideGet() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
func (s *CaptchaService) SlideGet() (any, error) {
url := fmt.Sprintf("%s/api/captcha/slide/get", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
@@ -90,12 +87,11 @@ func (s *CaptchaService) SlideGet() (interface{}, error) {
return res.Data, nil
}
func (s *CaptchaService) SlideCheck(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
func (s *CaptchaService) SlideCheck(data any) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {

View File

@@ -1,333 +0,0 @@
package crawler
import (
"context"
"errors"
"fmt"
"geekai/logger"
"net/url"
"strings"
"time"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/launcher"
"github.com/go-rod/rod/lib/proto"
)
// Service 网络爬虫服务
type Service struct {
browser *rod.Browser
}
// NewService 创建一个新的爬虫服务
func NewService() (*Service, error) {
// 启动浏览器
path, _ := launcher.LookPath()
u := launcher.New().Bin(path).
Headless(true). // 无头模式
Set("disable-web-security", ""). // 禁用网络安全限制
Set("disable-gpu", ""). // 禁用 GPU 加速
Set("no-sandbox", ""). // 禁用沙箱模式
Set("disable-setuid-sandbox", ""). // 禁用 setuid 沙箱
MustLaunch()
browser := rod.New().ControlURL(u).MustConnect()
return &Service{
browser: browser,
}, nil
}
// SearchResult 搜索结果
type SearchResult struct {
Title string `json:"title"` // 标题
URL string `json:"url"` // 链接
Content string `json:"content"` // 内容摘要
}
// WebSearch 网络搜索
func (s *Service) WebSearch(keyword string, maxPages int) ([]SearchResult, error) {
if keyword == "" {
return nil, errors.New("搜索关键词不能为空")
}
if maxPages <= 0 {
maxPages = 1
}
if maxPages > 10 {
maxPages = 10 // 最多搜索 10 页
}
results := make([]SearchResult, 0)
// 使用百度搜索
searchURL := fmt.Sprintf("https://www.baidu.com/s?wd=%s", url.QueryEscape(keyword))
// 设置页面超时
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 创建页面
page := s.browser.MustPage()
defer page.MustClose()
// 设置视口大小
err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{
Width: 1280,
Height: 800,
})
if err != nil {
return nil, fmt.Errorf("设置视口失败: %v", err)
}
// 导航到搜索页面
err = page.Context(ctx).Navigate(searchURL)
if err != nil {
return nil, fmt.Errorf("导航到搜索页面失败: %v", err)
}
// 等待搜索结果加载完成
err = page.WaitLoad()
if err != nil {
return nil, fmt.Errorf("等待页面加载完成失败: %v", err)
}
// 分析当前页面的搜索结果
for i := 0; i < maxPages; i++ {
if i > 0 {
// 点击下一页按钮
nextPage, err := page.Element("a.n")
if err != nil || nextPage == nil {
break // 没有下一页
}
err = nextPage.Click(proto.InputMouseButtonLeft, 1)
if err != nil {
break // 点击下一页失败
}
// 等待新页面加载
err = page.WaitLoad()
if err != nil {
break
}
}
// 提取搜索结果
resultElements, err := page.Elements(".result, .c-container")
if err != nil || resultElements == nil {
continue
}
for _, result := range resultElements {
// 获取标题
titleElement, err := result.Element("h3, .t")
if err != nil || titleElement == nil {
continue
}
title, err := titleElement.Text()
if err != nil {
continue
}
// 获取 URL
linkElement, err := titleElement.Element("a")
if err != nil || linkElement == nil {
continue
}
href, err := linkElement.Attribute("href")
if err != nil || href == nil {
continue
}
// 获取内容摘要 - 尝试多个可能的选择器
var contentElement *rod.Element
var content string
// 尝试多个可能的选择器来适应不同版本的百度搜索结果
selectors := []string{".content-right_8Zs40", ".c-abstract", ".content_LJ0WN", ".content"}
for _, selector := range selectors {
contentElement, err = result.Element(selector)
if err == nil && contentElement != nil {
content, _ = contentElement.Text()
if content != "" {
break
}
}
}
// 如果所有选择器都失败,尝试直接从结果块中提取文本
if content == "" {
// 获取结果元素的所有文本
fullText, err := result.Text()
if err == nil && fullText != "" {
// 简单处理:从全文中移除标题,剩下的可能是摘要
fullText = strings.Replace(fullText, title, "", 1)
// 清理文本
content = strings.TrimSpace(fullText)
// 限制内容长度
if len(content) > 200 {
content = content[:200] + "..."
}
}
}
// 添加到结果集
results = append(results, SearchResult{
Title: title,
URL: *href,
Content: content,
})
// 限制结果数量,每页最多 10 条
if len(results) >= 10*maxPages {
break
}
}
}
// 获取真实 URL百度搜索结果中的 URL 是短链接,需要跳转获取真实 URL
for i, result := range results {
realURL, err := s.getRedirectURL(result.URL)
if err == nil && realURL != "" {
results[i].URL = realURL
}
}
return results, nil
}
// 获取真实 URL
func (s *Service) getRedirectURL(shortURL string) (string, error) {
// 创建页面
page, err := s.browser.Page(proto.TargetCreateTarget{URL: ""})
if err != nil {
return shortURL, err // 返回原始URL
}
defer func() {
_ = page.Close()
}()
// 导航到短链接
err = page.Navigate(shortURL)
if err != nil {
return shortURL, err // 返回原始URL
}
// 等待重定向完成
time.Sleep(2 * time.Second)
// 获取当前 URL
info, err := page.Info()
if err != nil {
return shortURL, err // 返回原始URL
}
return info.URL, nil
}
// Close 关闭浏览器
func (s *Service) Close() error {
if s.browser != nil {
err := s.browser.Close()
s.browser = nil
return err
}
return nil
}
// SearchWeb 封装的搜索方法
func SearchWeb(keyword string, maxPages int) (string, error) {
// 添加panic恢复机制
defer func() {
if r := recover(); r != nil {
log := logger.GetLogger()
log.Errorf("爬虫服务崩溃: %v", r)
}
}()
service, err := NewService()
if err != nil {
return "", fmt.Errorf("创建爬虫服务失败: %v", err)
}
defer service.Close()
// 设置超时上下文
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
// 使用goroutine和通道来处理超时
resultChan := make(chan []SearchResult, 1)
errChan := make(chan error, 1)
go func() {
results, err := service.WebSearch(keyword, maxPages)
if err != nil {
errChan <- err
return
}
resultChan <- results
}()
// 等待结果或超时
select {
case <-ctx.Done():
return "", fmt.Errorf("搜索超时: %v", ctx.Err())
case err := <-errChan:
return "", fmt.Errorf("搜索失败: %v", err)
case results := <-resultChan:
if len(results) == 0 {
return "未找到关于 \"" + keyword + "\" 的相关搜索结果", nil
}
// 格式化结果
var builder strings.Builder
builder.WriteString(fmt.Sprintf("为您找到关于 \"%s\" 的 %d 条搜索结果:\n\n", keyword, len(results)))
for i, result := range results {
// // 尝试打开链接获取实际内容
// page := service.browser.MustPage()
// defer page.MustClose()
// // 设置页面超时
// pageCtx, pageCancel := context.WithTimeout(context.Background(), 10*time.Second)
// defer pageCancel()
// // 导航到目标页面
// err := page.Context(pageCtx).Navigate(result.URL)
// if err == nil {
// // 等待页面加载
// _ = page.WaitLoad()
// // 获取页面标题
// title, err := page.Eval("() => document.title")
// if err == nil && title.Value.String() != "" {
// result.Title = title.Value.String()
// }
// // 获取页面主要内容
// if content, err := page.Element("body"); err == nil {
// if text, err := content.Text(); err == nil {
// // 清理并截取内容
// text = strings.TrimSpace(text)
// if len(text) > 200 {
// text = text[:200] + "..."
// }
// result.Prompt = text
// }
// }
// }
builder.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, result.Title))
builder.WriteString(fmt.Sprintf(" 链接: %s\n", result.URL))
if result.Content != "" {
builder.WriteString(fmt.Sprintf(" 摘要: %s\n", result.Content))
}
builder.WriteString("\n")
}
return builder.String(), nil
}
}

View File

@@ -16,6 +16,7 @@ import (
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"time"
"github.com/go-redis/redis/v8"
@@ -94,12 +95,14 @@ func (s *Service) Run() {
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
Style string `json:"style,omitempty"`
Model string `json:"model"`
Image []string `json:"image,omitempty"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
Style string `json:"style,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type imgRes struct {
@@ -122,15 +125,6 @@ type ErrRes struct {
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
}
var chatModel model.ChatModel
if task.ModelId > 0 {
@@ -160,12 +154,17 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
reqBody := imgReq{
Model: chatModel.Value,
Prompt: prompt,
Prompt: task.Prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}
// 图片编辑
if len(task.Image) > 0 {
reqBody.Prompt = fmt.Sprintf("%s, %s", strings.Join(task.Image, " "), task.Prompt)
}
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
@@ -188,7 +187,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
var imgURL string
var data = map[string]interface{}{
"progress": 100,
"prompt": prompt,
"prompt": task.Prompt,
}
// 如果返回的是base64则需要上传到oss
if res.Data[0].B64Json != "" {
@@ -210,11 +209,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
var content string
if sync {
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", task.Prompt, imgURL)
}
return content, nil

View File

@@ -3,8 +3,10 @@ package jimeng
import (
"encoding/json"
"fmt"
"geekai/core/types"
"net/http"
"net/url"
"strings"
"github.com/volcengine/volc-sdk-golang/base"
"github.com/volcengine/volc-sdk-golang/service/visual"
@@ -13,14 +15,22 @@ import (
// Client 即梦API客户端
type Client struct {
visual *visual.Visual
config types.JimengConfig
}
// NewClient 创建即梦API客户端
func NewClient(accessKey, secretKey string) *Client {
func NewClient(sysConfig *types.SystemConfig) *Client {
client := &Client{}
client.UpdateConfig(sysConfig.Jimeng)
return client
}
func (c *Client) UpdateConfig(config types.JimengConfig) error {
// 使用官方SDK的visual实例
visualInstance := visual.NewInstance()
visualInstance.Client.SetAccessKey(accessKey)
visualInstance.Client.SetSecretKey(secretKey)
visualInstance.Client.SetAccessKey(config.AccessKey)
visualInstance.Client.SetSecretKey(config.SecretKey)
// 添加即梦AI专有的API配置
jimengApis := map[string]*base.ApiInfo{
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
visualInstance.Client.ApiInfoList[name] = info
}
return &Client{
visual: visualInstance,
c.config = config
c.visual = visualInstance
return c.testConnection()
}
// testConnection 测试即梦AI连接
func (c *Client) testConnection() error {
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
}
_, err := c.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
}
// SubmitTask 提交异步任务

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"gorm.io/gorm"
@@ -16,8 +15,6 @@ import (
"geekai/store/model"
"geekai/utils"
"geekai/core/types"
"github.com/go-redis/redis/v8"
)
@@ -36,17 +33,8 @@ type Service struct {
}
// NewService 创建即梦服务
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
// 从数据库加载配置
var config model.Config
db.Where("name = ?", "Jimeng").First(&config)
var jimengConfig types.JimengConfig
if config.Id > 0 {
_ = utils.JsonDecode(config.Value, &jimengConfig)
}
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
ctx, cancel := context.WithCancel(context.Background())
return &Service{
db: db,
@@ -378,7 +366,7 @@ func (s *Service) pollTaskStatus() {
for _, job := range jobs {
// 任务超时处理
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
if job.UpdatedAt.Before(time.Now().Add(-10 * time.Minute)) {
s.handleTaskError(job.Id, "task timeout")
continue
}
@@ -391,7 +379,7 @@ func (s *Service) pollTaskStatus() {
})
if err != nil {
logger.Errorf("query jimeng task status failed: %v", err)
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
continue
}
@@ -446,9 +434,7 @@ func (s *Service) pollTaskStatus() {
s.handleTaskError(job.Id, "task not found")
case model.JMTaskStatusExpired:
// 任务过期
s.handleTaskError(job.Id, "task expired")
continue
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
}
@@ -524,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
}
return &job, nil
}
// testConnection 测试即梦AI连接
func (s *Service) testConnection(accessKey, secretKey string) error {
testClient := NewClient(accessKey, secretKey)
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
}
_, err := testClient.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
}
// UpdateClientConfig 更新客户端配置
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
// 创建新的客户端
newClient := NewClient(accessKey, secretKey)
// 测试新客户端是否可用
err := s.testConnection(accessKey, secretKey)
if err != nil {
return err
}
// 更新客户端
s.client = newClient
return nil
}
var defaultPower = types.JimengPower{
TextToImage: 20,
ImageToImage: 20,
ImageEdit: 20,
ImageEffects: 20,
TextToVideo: 300,
ImageToVideo: 300,
}
// GetConfig 获取即梦AI配置
func (s *Service) GetConfig() *types.JimengConfig {
var config model.Config
err := s.db.Where("name", "jimeng").First(&config).Error
if err != nil {
// 如果配置不存在,返回默认配置
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
var jimengConfig types.JimengConfig
err = utils.JsonDecode(config.Value, &jimengConfig)
if err != nil {
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
return &jimengConfig
}

View File

@@ -8,30 +8,37 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"time"
"github.com/imroc/req/v3"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm"
)
type LicenseService struct {
config types.ApiConfig
levelDB *store.LevelDB
license *types.License
urlWhiteList []string
machineId string
db *gorm.DB
}
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
var license types.License
func NewLicenseService(sysConfig *types.SystemConfig, db *gorm.DB) *LicenseService {
var machineId string
info, err := host.Info()
if err == nil {
machineId = info.HostID
}
logger.Infof("License: %+v", sysConfig.License)
return &LicenseService{
config: server.Config.ApiConfig,
levelDB: levelDB,
license: &license,
machineId: "",
license: &sysConfig.License,
machineId: machineId,
db: db,
}
}
@@ -46,15 +53,15 @@ type License struct {
}
// ActiveLicense 激活 License
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
func (s *LicenseService) ActiveLicense(license string) error {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
response, err := req.C().R().
SetBody(map[string]string{"license": license, "machine_id": machineId}).
SetBody(map[string]string{"license": license, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return fmt.Errorf("发送激活请求失败: %v", err)
@@ -68,17 +75,24 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
return fmt.Errorf("激活失败:%v", res.Message)
}
if res.Data.ExpiredAt > 0 && res.Data.ExpiredAt < time.Now().Unix() {
return fmt.Errorf("License 已过期")
}
s.license = &types.License{
Key: license,
MachineId: machineId,
MachineId: s.machineId,
Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}
err = s.levelDB.Put(types.LicenseKey, s.license)
// 保存 License 到数据库
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
if err != nil {
return fmt.Errorf("保存许可证书失败:%v", err)
return fmt.Errorf("保存 License 到数据库失败: %v", err)
}
return nil
}
@@ -96,6 +110,11 @@ func (s *LicenseService) SyncLicense() {
s.license.IsActive = false
} else {
s.license = license
// 保存 License 到数据库
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
if err != nil {
logger.Errorf("保存 License 到数据库失败: %v", err)
}
}
urls, err := s.fetchUrlWhiteList()
@@ -109,33 +128,30 @@ func (s *LicenseService) SyncLicense() {
}
func (s *LicenseService) fetchLicense() (*types.License, error) {
//var res struct {
// Code types.BizCode `json:"code"`
// Message string `json:"message"`
// Data License `json:"data"`
//}
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
//response, err := req.C().R().
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
// SetSuccessResult(&res).Post(apiURL)
//if err != nil {
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
//}
//if response.IsErrorState() {
// return nil, fmt.Errorf("激活失败:%v", response.Status)
//}
//if res.Code != types.Success {
// return nil, fmt.Errorf("激活失败:%v", res.Message)
//}
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
response, err := req.C().R().
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return nil, fmt.Errorf("License 同步失败: %v", err)
}
if response.IsErrorState() {
return nil, fmt.Errorf("License 同步失败:%v", response.Status)
}
if res.Code != types.Success {
return nil, fmt.Errorf("License 同步失败:%v", res.Message)
}
return &types.License{
Key: "abc",
MachineId: "abc",
Configs: types.LicenseConfig{
UserNum: 10000,
DeCopy: false,
},
ExpiredAt: 0,
Key: res.Data.License,
MachineId: res.Data.MachineId,
Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}, nil
}
@@ -146,7 +162,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
Message string `json:"message"`
Data []string `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err)
@@ -163,35 +179,46 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
// GetLicense 获取许可信息
func (s *LicenseService) GetLicense() *types.License {
if s.license == nil {
var config model.Config
s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).First(&config)
if config.Value != "" {
utils.JsonDecode(config.Value, &s.license)
}
}
return s.license
}
func (s *LicenseService) SetLicense(licenseKey string) {
s.license.Key = licenseKey
}
// IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行
return nil
//if s.license.IsActive {
// if s.license.MachineId != s.machineId {
// return errors.New("系统使用了盗版的许可证书")
// }
//
// if time.Now().Unix() > s.license.ExpiredAt {
// return errors.New("系统许可证书已经过期")
// }
// return nil
//}
//
//if len(s.urlWhiteList) == 0 {
// urls, err := s.fetchUrlWhiteList()
// if err == nil {
// s.urlWhiteList = urls
// }
//}
//
//for _, v := range s.urlWhiteList {
// if strings.HasPrefix(uri, v) {
// return nil
// }
//}
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
if s.license.IsActive {
if s.license.MachineId != s.machineId {
return errors.New("系统使用了盗版的许可证书")
}
if time.Now().Unix() > s.license.ExpiredAt {
return errors.New("系统许可证书已经过期")
}
return nil
}
if len(s.urlWhiteList) == 0 {
urls, err := s.fetchUrlWhiteList()
if err == nil {
s.urlWhiteList = urls
}
}
for _, v := range s.urlWhiteList {
if strings.HasPrefix(uri, v) {
return nil
}
}
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
}

View File

@@ -1,52 +1,342 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// Copyright 2023 The Geek-AI Authors. All rights reserved.
// Use of this source code is governed by a Apache-2.0 license
// that can be found in the LICENSE file.
// @Author yangjian102621@163.com
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store"
"geekai/store/model"
"strings"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
const (
// 迁移状态Redis key
MigrationStatusKey = "config_migration:status"
// 迁移完成标志
MigrationCompleted = "completed"
)
// MigrationService 配置迁移服务
type MigrationService struct {
db *gorm.DB
db *gorm.DB
redisClient *redis.Client
appConfig *types.AppConfig
levelDB *store.LevelDB
licenseService *LicenseService
}
func NewMigrationService(db *gorm.DB) *MigrationService {
return &MigrationService{db: db}
func NewMigrationService(db *gorm.DB, redisClient *redis.Client, appConfig *types.AppConfig, levelDB *store.LevelDB, licenseService *LicenseService) *MigrationService {
return &MigrationService{
db: db,
redisClient: redisClient,
appConfig: appConfig,
levelDB: levelDB,
licenseService: licenseService,
}
}
func (s *MigrationService) Migrate() error {
err := s.db.AutoMigrate(
&model.AdminUser{},
&model.ApiKey{},
&model.AppType{},
&model.ChatItem{},
&model.ChatMessage{},
&model.ChatModel{},
&model.ChatRole{},
&model.Config{},
&model.DallJob{},
&model.File{},
&model.Function{},
&model.InviteCode{},
&model.InviteLog{},
&model.Menu{},
&model.MidJourneyJob{},
&model.Order{},
&model.PowerLog{},
&model.Product{},
&model.Redeem{},
&model.SdJob{},
&model.SunoJob{},
&model.User{},
&model.UserLoginLog{},
&model.VideoJob{},
)
return err
func (s *MigrationService) StartMigrate() {
go func() {
s.MigrateConfig(s.appConfig)
s.TableMigration()
s.MigrateLicense()
}()
}
// 迁移 License
func (s *MigrationService) MigrateLicense() {
key := "migrate:license"
if s.redisClient.Get(context.Background(), key).Val() == "1" {
logger.Info("License 已迁移,跳过迁移")
return
}
logger.Info("开始迁移 License...")
var license types.License
err := s.levelDB.Get(types.LicenseKey, &license)
if err != nil {
license = types.License{
Key: "",
MachineId: "",
Configs: types.LicenseConfig{UserNum: 0, DeCopy: false},
ExpiredAt: 0,
IsActive: false,
}
}
logger.Infof("迁移 License: %+v", license)
if err := s.saveConfig(types.ConfigKeyLicense, license); err != nil {
logger.Errorf("迁移 License 失败: %v", err)
return
}
s.licenseService.SetLicense(license.Key)
logger.Info("迁移 License 完成")
s.redisClient.Set(context.Background(), key, "1", 0)
}
// 迁移配置内容
func (s *MigrationService) MigrateConfigContent() error {
// 用户协议
if err := s.saveConfig(types.ConfigKeyPrivacy, map[string]string{
"content": "用户协议内容",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 隐私政策
if err := s.saveConfig(types.ConfigKeyAgreement, map[string]string{
"content": "隐私政策内容",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 思维导图
if err := s.saveConfig(types.ConfigKeyMarkMap, map[string]string{
"content": `# GeekAI 演示站
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
- 基于 Websocket 实现,完美的打字机体验。
- 内置了各种预训练好的角色应用,轻松满足你的各种聊天和应用需求。
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型。
- 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件。`,
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 微信登录配置
if err := s.saveConfig(types.ConfigKeyWxLogin, map[string]string{
"api_key": "",
"notify_url": "",
"enabled": "false",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 验证码配置
if err := s.saveConfig(types.ConfigKeyCaptcha, map[string]string{
"api_key": "",
"type": "dot",
"enabled": "false",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 文本审核
if err := s.saveConfig(types.ConfigKeyModeration, map[string]any{
"enable": "false",
"active": "gitee",
"enable_guide": "false",
"guide_prompt": "",
"gitee": map[string]string{
"api_key": "",
"model": "Security-semantic-filtering",
},
"baidu": map[string]string{
"access_key": "",
"secret_key": "",
},
"tencent": map[string]string{
"access_key": "",
"secret_key": "",
},
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
return nil
}
// 数据表迁移
func (s *MigrationService) TableMigration() {
// 新数据表
s.db.AutoMigrate(&model.Moderation{})
// 订单字段整理
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
}
if !s.db.Migrator().HasColumn(&model.Order{}, "checked") {
s.db.Migrator().AddColumn(&model.Order{}, "checked")
}
// 重命名 config 表字段
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
}
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
}
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
}
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
s.db.Migrator().DropIndex(&model.Config{}, "marker")
}
// 手动删除字段
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
}
if s.db.Migrator().HasColumn(&model.Product{}, "discount") {
s.db.Migrator().DropColumn(&model.Product{}, "discount")
}
if s.db.Migrator().HasColumn(&model.Product{}, "days") {
s.db.Migrator().DropColumn(&model.Product{}, "days")
}
if s.db.Migrator().HasColumn(&model.Product{}, "app_url") {
s.db.Migrator().DropColumn(&model.Product{}, "app_url")
}
if s.db.Migrator().HasColumn(&model.Product{}, "url") {
s.db.Migrator().DropColumn(&model.Product{}, "url")
}
}
// 迁移配置数据
func (s *MigrationService) MigrateConfig(config *types.AppConfig) error {
logger.Info("开始迁移配置到数据库...")
// 迁移支付配置
if err := s.migratePaymentConfig(config); err != nil {
logger.Errorf("迁移支付配置失败: %v", err)
return err
}
// 迁移存储配置
if err := s.migrateStorageConfig(config); err != nil {
logger.Errorf("迁移存储配置失败: %v", err)
return err
}
// 迁移通信配置
if err := s.migrateCommunicationConfig(config); err != nil {
logger.Errorf("迁移通信配置失败: %v", err)
return err
}
// 迁移配置内容
if err := s.MigrateConfigContent(); err != nil {
logger.Errorf("迁移配置内容失败: %v", err)
return err
}
logger.Info("配置迁移完成")
return nil
}
// 迁移支付配置
func (s *MigrationService) migratePaymentConfig(config *types.AppConfig) error {
paymentConfig := types.PaymentConfig{
Alipay: config.AlipayConfig,
Epay: config.GeekPayConfig,
WxPay: config.WechatPayConfig,
}
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
return err
}
return nil
}
// 迁移存储配置
func (s *MigrationService) migrateStorageConfig(config *types.AppConfig) error {
ossConfig := types.OSSConfig{
Active: config.OSS.Active,
Local: config.OSS.Local,
Minio: config.OSS.Minio,
QiNiu: config.OSS.QiNiu,
AliYun: config.OSS.AliYun,
}
return s.saveConfig(types.ConfigKeyOss, ossConfig)
}
// 迁移通信配置
func (s *MigrationService) migrateCommunicationConfig(config *types.AppConfig) error {
// SMTP配置
smtpConfig := map[string]any{
"use_tls": config.SmtpConfig.UseTls,
"host": config.SmtpConfig.Host,
"port": config.SmtpConfig.Port,
"app_name": config.SmtpConfig.AppName,
"from": config.SmtpConfig.From,
"password": config.SmtpConfig.Password,
}
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
return err
}
// 短信配置
smsConfig := map[string]any{
"active": strings.ToLower(config.SMS.Active),
"aliyun": map[string]any{
"access_key": config.SMS.Ali.AccessKey,
"access_secret": config.SMS.Ali.AccessSecret,
"sign": config.SMS.Ali.Sign,
"code_temp_id": config.SMS.Ali.CodeTempId,
},
"bao": map[string]any{
"username": config.SMS.Bao.Username,
"password": config.SMS.Bao.Password,
"sign": config.SMS.Bao.Sign,
"code_template": config.SMS.Bao.CodeTemplate,
},
}
return s.saveConfig(types.ConfigKeySms, smsConfig)
}
// 保存配置到数据库
func (s *MigrationService) saveConfig(key string, config any) error {
// 检查是否已存在
var existingConfig model.Config
if err := s.db.Where("name", key).First(&existingConfig).Error; err == nil {
// 配置已存在,跳过
logger.Infof("配置 %s 已存在,跳过迁移", key)
return nil
}
// 序列化配置
configJSON, err := json.Marshal(config)
if err != nil {
return err
}
// 保存到数据库
newConfig := model.Config{
Name: key,
Value: string(configJSON),
}
if err := s.db.Create(&newConfig).Error; err != nil {
return err
}
logger.Infof("成功迁移配置 %s", key)
return nil
}

View File

@@ -67,25 +67,6 @@ func (s *Service) Run() {
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// use fast mode as default
if task.Mode == "" {
task.Mode = "fast"

View File

@@ -0,0 +1,33 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
)
type BaiduAIModeration struct {
config types.ModerationBaiduConfig
}
func NewBaiduAIModeration(sysConfig *types.SystemConfig) *BaiduAIModeration {
return &BaiduAIModeration{
config: sysConfig.Moderation.Baidu,
}
}
func (s *BaiduAIModeration) UpdateConfig(config types.ModerationBaiduConfig) {
s.config = config
}
func (s *BaiduAIModeration) Moderate(text string) (types.ModerationResult, error) {
return types.ModerationResult{}, errors.New("not implemented")
}
var _ Service = (*BaiduAIModeration)(nil)

View File

@@ -0,0 +1,58 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
"github.com/imroc/req/v3"
)
type GiteeAIModeration struct {
config types.ModerationGiteeConfig
apiURL string
}
func NewGiteeAIModeration(sysConfig *types.SystemConfig) *GiteeAIModeration {
return &GiteeAIModeration{
config: sysConfig.Moderation.Gitee,
apiURL: "https://ai.gitee.com/v1/moderations",
}
}
func (s *GiteeAIModeration) UpdateConfig(config types.ModerationGiteeConfig) {
s.config = config
}
type GiteeAIModerationResult struct {
ID string `json:"id"`
Model string `json:"model"`
Results []types.ModerationResult `json:"results"`
}
func (s *GiteeAIModeration) Moderate(text string) (types.ModerationResult, error) {
body := map[string]any{
"input": text,
"model": s.config.Model,
}
var res GiteeAIModerationResult
r, err := req.C().R().SetHeader("Authorization", "Bearer "+s.config.ApiKey).SetBody(body).SetSuccessResult(&res).Post(s.apiURL)
if err != nil {
return types.ModerationResult{}, err
}
if r.IsErrorState() {
return types.ModerationResult{}, errors.New(r.String())
}
return res.Results[0], nil
}
var _ Service = (*GiteeAIModeration)(nil)

View File

@@ -0,0 +1,58 @@
package moderation
import (
"geekai/core/types"
logger2 "geekai/logger"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
var logger = logger2.GetLogger()
type Service interface {
Moderate(text string) (types.ModerationResult, error)
}
type ServiceManager struct {
gitee *GiteeAIModeration
baidu *BaiduAIModeration
tencent *TencentAIModeration
active string
}
func NewServiceManager(gitee *GiteeAIModeration, baidu *BaiduAIModeration, tencent *TencentAIModeration) *ServiceManager {
return &ServiceManager{
gitee: gitee,
baidu: baidu,
tencent: tencent,
}
}
func (s *ServiceManager) GetService() Service {
switch s.active {
case types.ModerationBaidu:
return s.baidu
case types.ModerationTencent:
return s.tencent
default:
return s.gitee
}
}
func (s *ServiceManager) UpdateConfig(config types.ModerationConfig) {
switch config.Active {
case types.ModerationGitee:
s.gitee.UpdateConfig(config.Gitee)
case types.ModerationBaidu:
s.baidu.UpdateConfig(config.Baidu)
case types.ModerationTencent:
s.tencent.UpdateConfig(config.Tencent)
}
s.active = config.Active
}

View File

@@ -0,0 +1,33 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
)
type TencentAIModeration struct {
config types.ModerationTencentConfig
}
func NewTencentAIModeration(sysConfig *types.SystemConfig) *TencentAIModeration {
return &TencentAIModeration{
config: sysConfig.Moderation.Tencent,
}
}
func (s *TencentAIModeration) UpdateConfig(config types.ModerationTencentConfig) {
s.config = config
}
func (s *TencentAIModeration) Moderate(text string) (types.ModerationResult, error) {
return types.ModerationResult{}, errors.New("not implemented")
}
var _ Service = (*TencentAIModeration)(nil)

View File

@@ -23,35 +23,35 @@ import (
)
type AliYunOss struct {
config *types.AliYunOssConfig
config types.AliYunOssConfig
bucket *oss.Bucket
proxyURL string
}
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
config := &appConfig.OSS.AliYun
// 创建 OSS 客户端
func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
s := &AliYunOss{
proxyURL: appConfig.ProxyURL,
}
err := s.UpdateConfig(sysConfig.OSS.AliYun)
if err != nil {
logger.Warnf("阿里云OSS初始化失败: %v", err)
}
return s, nil
}
func (s *AliYunOss) UpdateConfig(config types.AliYunOssConfig) error {
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
if err != nil {
return nil, err
return err
}
// 获取存储空间
bucket, err := client.Bucket(config.Bucket)
if err != nil {
return nil, err
return err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return &AliYunOss{
config: config,
bucket: bucket,
proxyURL: appConfig.ProxyURL,
}, nil
s.bucket = bucket
s.config = config
return nil
}
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
@@ -68,7 +68,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer src.Close()
fileExt := filepath.Ext(file.Filename)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
// 上传文件
err = s.bucket.PutObject(objectKey, src)
if err != nil {
@@ -102,7 +102,7 @@ func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
if ext == "" {
ext = filepath.Ext(parse.Path)
}
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
if err != nil {
@@ -116,7 +116,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
@@ -128,8 +128,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
func (s AliYunOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
objectKey = filepath.Base(fileURL)
} else {
objectKey = fileURL
}

View File

@@ -21,17 +21,21 @@ import (
)
type LocalStorage struct {
config *types.LocalStorageConfig
config types.LocalStorageConfig
proxyURL string
}
func NewLocalStorage(config *types.AppConfig) LocalStorage {
return LocalStorage{
config: &config.OSS.Local,
proxyURL: config.ProxyURL,
func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
return &LocalStorage{
config: sysConfig.OSS.Local,
proxyURL: appConfig.ProxyURL,
}
}
func (s *LocalStorage) UpdateConfig(config types.LocalStorageConfig) {
s.config = config
}
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name)
if err != nil {

View File

@@ -24,24 +24,32 @@ import (
)
type MiniOss struct {
config *types.MiniOssConfig
config types.MiniOssConfig
client *minio.Client
proxyURL string
}
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
config := &appConfig.OSS.Minio
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
s := &MiniOss{proxyURL: appConfig.ProxyURL}
err := s.UpdateConfig(sysConfig.OSS.Minio)
if err != nil {
logger.Warnf("MinioOSS初始化失败: %v", err)
}
return s, nil
}
func (s *MiniOss) UpdateConfig(config types.MiniOssConfig) error {
minioClient, err := minio.New(config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
Secure: config.UseSSL,
})
if err != nil {
return MiniOss{}, err
return err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
s.config = config
s.client = minioClient
return nil
}
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
@@ -62,7 +70,7 @@ func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string,
if ext == "" {
ext = filepath.Ext(parse.Path)
}
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
@@ -89,7 +97,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer fileReader.Close()
fileExt := filepath.Ext(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Body-Type"),
})
@@ -111,7 +119,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
@@ -128,8 +136,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
func (s MiniOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
objectKey = filepath.Base(fileURL)
} else {
objectKey = fileURL
}

View File

@@ -24,18 +24,24 @@ import (
"github.com/qiniu/go-sdk/v7/storage"
)
type QinNiuOss struct {
config *types.QiNiuOssConfig
type QiNiuOss struct {
config types.QiNiuOssConfig
mac *qbox.Mac
putPolicy storage.PutPolicy
uploader *storage.FormUploader
manager *storage.BucketManager
bucket *storage.BucketManager
proxyURL string
}
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
config := &appConfig.OSS.QiNiu
// build storage uploader
func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QiNiuOss {
s := &QiNiuOss{
proxyURL: appConfig.ProxyURL,
}
s.UpdateConfig(sysConfig.OSS.QiNiu)
return s
}
func (s *QiNiuOss) UpdateConfig(config types.QiNiuOssConfig) {
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
if !ok {
zone = storage.ZoneHuanan
@@ -47,20 +53,13 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
putPolicy := storage.PutPolicy{
Scope: config.Bucket,
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return QinNiuOss{
config: config,
mac: mac,
putPolicy: putPolicy,
uploader: formUploader,
manager: storage.NewBucketManager(mac, &storeConfig),
proxyURL: appConfig.ProxyURL,
}
s.config = config
s.mac = mac
s.putPolicy = putPolicy
s.uploader = formUploader
s.bucket = storage.NewBucketManager(mac, &storeConfig)
}
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
func (s QiNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单
file, err := ctx.FormFile(name)
if err != nil {
@@ -74,7 +73,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer src.Close()
fileExt := filepath.Ext(file.Filename)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
// 上传文件
ret := storage.PutRet{}
extra := storage.PutExtra{}
@@ -93,7 +92,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
func (s QiNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
@@ -111,7 +110,7 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
if ext == "" {
ext = filepath.Ext(parse.Path)
}
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
@@ -122,12 +121,12 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
func (s QiNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
@@ -138,16 +137,15 @@ func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) Delete(fileURL string) error {
func (s QiNiuOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
objectKey = filepath.Base(fileURL)
} else {
objectKey = fileURL
}
return s.manager.Delete(s.config.Bucket, objectKey)
return s.bucket.Delete(s.config.Bucket, objectKey)
}
var _ Uploader = QinNiuOss{}
var _ Uploader = QiNiuOss{}

View File

@@ -9,10 +9,10 @@ package oss
import "github.com/gin-gonic/gin"
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
const Local = "local"
const Minio = "minio"
const QiNiu = "qiniu"
const AliYun = "aliyun"
type File struct {
Name string `json:"name"`

View File

@@ -9,45 +9,58 @@ package oss
import (
"geekai/core/types"
"strings"
logger2 "geekai/logger"
)
var logger = logger2.GetLogger()
type UploaderManager struct {
handler Uploader
local *LocalStorage
aliyun *AliYunOss
mini *MiniOss
qiniu *QiNiuOss
active string
}
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
active := Local
if config.OSS.Active != "" {
active = strings.ToUpper(config.OSS.Active)
}
var handler Uploader
switch active {
case Local:
handler = NewLocalStorage(config)
break
case Minio:
client, err := NewMiniOss(config)
if err != nil {
return nil, err
}
handler = client
break
case QiNiu:
handler = NewQiNiuOss(config)
break
case AliYun:
client, err := NewAliYunOss(config)
if err != nil {
return nil, err
}
handler = client
break
func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QiNiuOss) (*UploaderManager, error) {
if sysConfig.OSS.Active == "" {
sysConfig.OSS.Active = Local
}
return &UploaderManager{handler: handler}, nil
return &UploaderManager{
active: sysConfig.OSS.Active,
local: local,
aliyun: aliyun,
mini: mini,
qiniu: qiniu,
}, nil
}
func (m *UploaderManager) GetUploadHandler() Uploader {
return m.handler
switch m.active {
case Local:
return m.local
case AliYun:
return m.aliyun
case Minio:
return m.mini
case QiNiu:
return m.qiniu
}
return m.local
}
func (m *UploaderManager) UpdateConfig(config types.OSSConfig) {
switch config.Active {
case Local:
m.local.UpdateConfig(config.Local)
case AliYun:
m.aliyun.UpdateConfig(config.AliYun)
case Minio:
m.mini.UpdateConfig(config.Minio)
case QiNiu:
m.qiniu.UpdateConfig(config.QiNiu)
}
m.active = config.Active
}

View File

@@ -12,129 +12,98 @@ import (
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
"net/http"
"os"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
)
type AlipayService struct {
config *types.AlipayConfig
client *alipay.Client
config *types.AlipayConfig
}
var logger = logger2.GetLogger()
func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
config := appConfig.AlipayConfig
func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) {
config := sysConfig.Payment.Alipay
if !config.Enabled {
logger.Info("Disabled Alipay service")
return nil, nil
logger.Debug("Disabled Alipay service")
}
priKey, err := readKey(config.PrivateKey)
service := &AlipayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("支付宝服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error {
client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
return fmt.Errorf("error with initialize alipay service: %v", err)
}
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
if err != nil {
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
s.client = client
s.config = config
if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("Alipay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
}
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with load payment public key: %v", err)
}
return &AlipayService{config: &config, client: client}, nil
return nil
}
type AlipayParams struct {
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"subject"`
TotalFee string `json:"total_fee"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("quit_url", params.ReturnURL)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
func (s *AlipayService) Pay(params PayRequest) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
return s.client.TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) {
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with trade query: %v", err)
}
switch rsp.Response.TradeStatus {
case "TRADE_SUCCESS":
logger.Debugf("支付宝查询订单成功:%+v", rsp.Response)
return OrderInfo{
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Status: Success,
PayTime: rsp.Response.SendPayDate,
}, nil
case "TRADE_CLOSED":
return OrderInfo{Status: Closed}, nil
default:
return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus)
}
}
// TradeVerify 交易验证
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with parse notify request: " + err.Error(),
}
return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err)
}
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with verify sign: " + err.Error(),
}
return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err)
}
return s.TradeQuery(request.Form.Get("out_trade_no"))
return s.Query(request.Form.Get("out_trade_no"))
}
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
//查询订单
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
}
}
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
return NotifyVo{
Status: Success,
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Subject: rsp.Response.Subject,
Message: "OK",
}
} else {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo,
}
}
}
func readKey(filename string) (string, error) {
data, err := os.ReadFile(filename)
if err != nil {
return "", err
}
return string(data), nil
}
var _ PayService = (*AlipayService)(nil)

View File

@@ -22,41 +22,30 @@ import (
"time"
)
// GeekPayService Geek 支付服务
type GeekPayService struct {
config *types.GeekPayConfig
// EPayService 支付服务
type EPayService struct {
config *types.EpayConfig
}
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
return &GeekPayService{
config: &appConfig.GeekPayConfig,
func NewEPayService(sysConfig *types.SystemConfig) *EPayService {
return &EPayService{
config: &sysConfig.Payment.Epay,
}
}
type GeekPayParams struct {
Method string `json:"method"` // 接口类型
Device string `json:"device"` // 设备类型
Type string `json:"type"` // 支付方式
OutTradeNo string `json:"out_trade_no"` // 商户订单号
Name string `json:"name"` // 商品名称
Money string `json:"money"` // 商品金额
ClientIP string `json:"clientip"` //用户IP地址
SubOpenId string `json:"sub_openid"` // 微信用户 openid仅小程序支付需要
SubAppId string `json:"sub_appid"` // 小程序 AppId仅小程序支付需要
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
func (s *EPayService) UpdateConfig(config *types.EpayConfig) {
s.config = config
}
// Pay 支付订单
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
func (s *EPayService) Pay(params PayRequest) (string, error) {
p := map[string]string{
"pid": s.config.AppId,
//"method": params.Method,
"pid": s.config.AppId,
"device": params.Device,
"type": params.Type,
"type": params.PayWay,
"out_trade_no": params.OutTradeNo,
"name": params.Name,
"money": params.Money,
"name": params.Subject,
"money": params.TotalFee,
"clientip": params.ClientIP,
"notify_url": params.NotifyURL,
"return_url": params.ReturnURL,
@@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
}
p["sign"] = s.Sign(p)
p["sign_type"] = "MD5"
return s.sendRequest(s.config.ApiURL, p)
resp, err := s.sendRequest(s.config.ApiURL, p)
if err != nil {
return "", err
}
if resp.Code != 1 {
return "", errors.New(resp.Msg)
}
if resp.PayURL != "" {
return resp.PayURL, nil
} else {
return resp.QrCode, nil
}
}
func (s *GeekPayService) Sign(params map[string]string) string {
func (s *EPayService) Sign(params map[string]string) string {
// 按字母顺序排序参数
var keys []string
for k := range params {
@@ -100,7 +100,7 @@ type GeekPayResp struct {
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
}
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
form := url.Values{}
for k, v := range params {
form.Add(k, v)
@@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string)
}
return &r, nil
}
func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) {
params := url.Values{}
params.Set("act", "order")
params.Set("pid", s.config.AppId)
params.Set("key", s.config.PrivateKey)
params.Set("out_trade_no", outTradeNo)
apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode())
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: tr}
resp, err := client.Get(apiURL)
if err != nil {
return OrderInfo{}, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return OrderInfo{}, err
}
logger.Debugf(string(body))
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Status string `json:"status"`
Name string `json:"name"`
Money string `json:"money"`
EndTime string `json:"endtime"`
TradeNo string `json:"trade_no"`
}
if err := json.Unmarshal(body, &result); err != nil {
return OrderInfo{}, errors.New("订单查询响应解析失败")
}
if result.Code != 1 {
return OrderInfo{}, errors.New(result.Msg)
}
logger.Debugf("订单信息:%+v", result)
orderInfo := OrderInfo{
OutTradeNo: outTradeNo,
TradeId: result.TradeNo,
Amount: result.Money,
PayTime: result.EndTime,
}
if result.Status == "1" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
var _ PayService = (*EPayService)(nil)

View File

@@ -1,171 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
type HuPiPayService struct {
appId string
appSecret string
apiURL string
}
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
return &HuPiPayService{
appId: config.HuPiPayConfig.AppId,
appSecret: config.HuPiPayConfig.AppSecret,
apiURL: config.HuPiPayConfig.ApiURL,
}
}
type HuPiPayParams struct {
AppId string `json:"appid"`
Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"`
TotalFee string `json:"total_fee"`
Title string `json:"title"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
WapName string `json:"wap_name"`
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
Type string `json:"type"`
WapUrl string `json:"wap_url"`
}
type HuPiPayResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId
params.Time = simple
params.NonceStr = simple
encode := utils.JsonEncode(params)
m := make(map[string]string)
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
data.Add(k, fmt.Sprintf("%v", v))
}
// 生成签名
data.Add("hash", s.Sign(data))
// 发送支付请求
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
}
var res HuPiPayResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
}
// Sign 签名方法
func (s *HuPiPayService) Sign(params url.Values) string {
params.Del(`Sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += s.appSecret
md5bs := md5.Sum([]byte(src))
return hex.EncodeToString(md5bs[:])
}
// Check 校验订单状态
func (s *HuPiPayService) Check(outTradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("out_trade_order", outTradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)
data.Add("hash", s.Sign(data))
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ErrCode int `json:"errcode"`
Data struct {
Status string `json:"status"`
OpenOrderId string `json:"open_order_id"`
} `json:"data,omitempty"`
ErrMsg string `json:"errmsg"`
Hash string `json:"hash"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ErrCode == 0 && r.Data.Status == "OD" {
return nil
} else {
logger.Debugf("%+v", r)
return errors.New("order not paid" + r.ErrMsg)
}
}

View File

@@ -0,0 +1,54 @@
package payment
// 支付渠道定义
const PayChannelAL = "alipay" // 支付宝
const PayChannelWX = "wxpay" // 微信支付
const PayChannelEpay = "epay" // 易支付
// 支付方式
const PayWayAL = "alipay"
const PayWayWX = "wxpay"
const (
Success = 0
Failure = 1
Closed = 2
)
type PayRequest struct {
OutTradeNo string // 商户订单号
Subject string // 商品名称
TotalFee string // 商品金额
ReturnURL string // 回调地址
NotifyURL string // 回调地址
// 易支付专有参数
Method string // 接口类型
Device string // 设备类型
PayWay string // 支付方式
ClientIP string //用户IP地址
OpenID string // 用户openid
}
type OrderInfo struct {
Mchid string // 商户号
OutTradeNo string // 商户订单号
TradeId string // 交易号
Amount string // 金额
Status int // 状态 0: 未支付 1: 已支付 2: 已关闭
PayTime string // 完成支付时间
}
func (o OrderInfo) Closed() bool {
return o.Status == Closed
}
func (o OrderInfo) Success() bool {
return o.Status == Success
}
type PayService interface {
Pay(params PayRequest) (string, error) // 生成支付链接
Query(outTradeNo string) (OrderInfo, error) // 查询订单
}

View File

@@ -1,19 +0,0 @@
package payment
type NotifyVo struct {
Status int
OutTradeNo string // 商户订单号
TradeId string // 交易ID
Amount string // 交易金额
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == Success
}
const (
Success = 0
Failure = 1
)

View File

@@ -1,144 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
"net/http"
"time"
)
type WechatPayService struct {
config *types.WechatPayConfig
client *wechat.ClientV3
}
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
config := appConfig.WechatPayConfig
if !config.Enabled {
logger.Info("Disabled WechatPay service")
return nil, nil
}
priKey, err := readKey(config.PrivateKey)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
if err != nil {
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
}
//client.DebugSwitch = gopay.DebugOn
return &WechatPayService{config: &config, client: client}, nil
}
type WechatPayParams struct {
OutTradeNo string `json:"out_trade_no"`
TotalFee int `json:"total_fee"`
Subject string `json:"subject"`
ClientIP string `json:"client_ip"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
})
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
}).
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP).
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap")
})
})
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.H5Url, nil
}
type NotifyResponse struct {
Code string `json:"code"`
Message string `xml:"message"`
}
// TradeVerify 交易验证
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
}
// TODO: 这里验签程序有 Bug一直报错crypto/rsa: verification error先暂时取消验签
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
//if err != nil {
// return fmt.Errorf("error with client v3 verify sign: %v", err)
//}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
}
return NotifyVo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
}
}

View File

@@ -0,0 +1,217 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/http"
"os"
"time"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
)
type WxPayService struct {
config *types.WxPayConfig
client *wechat.ClientV3
}
func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) {
config := sysConfig.Payment.WxPay
if !config.Enabled {
logger.Debug("Disabled WechatPay service")
}
service := &WxPayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("微信支付服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error {
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
if err != nil {
return fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return fmt.Errorf("error with autoVerifySign: %v", err)
}
s.client = client
if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("WechatPay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
}
s.config = config
return nil
}
func (s *WxPayService) Pay(params PayRequest) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", utils.IntValue(params.TotalFee, 0)).
Set("currency", "CNY")
})
logger.Debugf("wxpay params: %+v", bm)
if params.Device == "mobile" {
bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP)
}).SetBodyMap("payer", func(bm gopay.BodyMap) {
bm.Set("openid", params.OpenID)
})
wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.PrepayId, nil
} else if params.Device == "pc" {
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
return "", nil
}
func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) {
wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err)
}
if wxRsp.Code != wechat.Success {
return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error)
}
if wxRsp.Response.TradeState == "CLOSED" {
return OrderInfo{Status: Closed}, nil
}
orderInfo := OrderInfo{
OutTradeNo: wxRsp.Response.OutTradeNo,
TradeId: wxRsp.Response.TransactionId,
Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100),
PayTime: wxRsp.Response.SuccessTime,
}
if wxRsp.Response.TradeState == "SUCCESS" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
// TradeVerify 交易验证
func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err)
}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err)
}
return OrderInfo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
PayTime: result.SuccessTime,
}, nil
}
// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// })
// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.CodeUrl, nil
// }
// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// }).
// SetBodyMap("scene_info", func(bm gopay.BodyMap) {
// bm.Set("payer_client_ip", params.ClientIP).
// SetBodyMap("h5_info", func(bm gopay.BodyMap) {
// bm.Set("type", "Wap")
// })
// })
// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.H5Url, nil
// }
// type NotifyResponse struct {
// Code string `json:"code"`
// Message string `xml:"message"`
// }
var _ PayService = (*WxPayService)(nil)

View File

@@ -10,36 +10,57 @@ package sms
import (
"fmt"
"geekai/core/types"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
)
type AliYunSmsService struct {
config *types.SmsConfigAli
config types.SmsConfigAli
client *dysmsapi.Client
domain string
zoneId string
}
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
config := &appConfig.SMS.Ali
// 创建阿里云短信客户端
func NewAliYunSmsService(sysConfig *types.SystemConfig) (*AliYunSmsService, error) {
config := sysConfig.SMS.Ali
domain := "dysmsapi.aliyuncs.com"
zoneId := "cn-hangzhou"
s := AliYunSmsService{
config: config,
domain: domain,
zoneId: zoneId,
}
if sysConfig.SMS.Active == Ali {
err := s.UpdateConfig(config)
if err != nil {
logger.Errorf("阿里云短信初始化失败: %v", err)
}
}
return &s, nil
}
func (s *AliYunSmsService) UpdateConfig(config types.SmsConfigAli) error {
client, err := dysmsapi.NewClientWithAccessKey(
"cn-hangzhou",
s.zoneId,
config.AccessKey,
config.AccessSecret)
if err != nil {
return nil, fmt.Errorf("failed to create client: %v", err)
return fmt.Errorf("failed to create client: %v", err)
}
return &AliYunSmsService{
config: config,
client: client,
}, nil
s.client = client
s.config = config
return nil
}
func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
if s.client == nil {
return fmt.Errorf("阿里云短信服务未初始化")
}
// 创建短信请求并设置参数
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.Domain = s.config.Domain
request.Domain = s.domain
request.PhoneNumbers = mobile
request.SignName = s.config.Sign
request.TemplateCode = s.config.CodeTempId

View File

@@ -19,20 +19,21 @@ import (
)
type BaoSmsService struct {
config *types.SmsConfigBao
config types.SmsConfigBao
domain string
}
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
config := appConfig.SMS.Bao
if config.Domain == "" { // use default domain
config.Domain = "api.smsbao.com"
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
}
func NewBaoSmsService(sysConfig *types.SystemConfig) *BaoSmsService {
return &BaoSmsService{
config: &config,
config: sysConfig.SMS.Bao,
domain: "api.smsbao.com",
}
}
func (s *BaoSmsService) UpdateConfig(config types.SmsConfigBao) {
s.config = config
}
var errMsg = map[string]string{
"0": "短信发送成功",
"-1": "参数不全",
@@ -56,7 +57,7 @@ func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
params.Set("m", mobile)
params.Set("c", content)
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
apiURL := fmt.Sprintf("https://%s/sms?%s", s.domain, params.Encode())
response, err := http.Get(apiURL)
if err != nil {
return err

View File

@@ -7,8 +7,8 @@ package sms
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const Ali = "ALI"
const Bao = "BAO"
const Ali = "aliyun"
const Bao = "bao"
type Service interface {
SendVerifyCode(mobile string, code int) error

View File

@@ -1,46 +0,0 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core/types"
logger2 "geekai/logger"
"strings"
)
type ServiceManager struct {
handler Service
}
var logger = logger2.GetLogger()
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
active := Ali
if config.SMS.Active != "" {
active = strings.ToUpper(config.SMS.Active)
}
var handler Service
switch active {
case Ali:
client, err := NewAliYunSmsService(config)
if err != nil {
return nil, err
}
handler = client
break
case Bao:
handler = NewSmsBaoSmsService(config)
break
}
return &ServiceManager{handler: handler}, nil
}
func (m *ServiceManager) GetService() Service {
return m.handler
}

View File

@@ -0,0 +1,54 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core/types"
logger2 "geekai/logger"
)
type SmsManager struct {
aliyun *AliYunSmsService
bao *BaoSmsService
active string
}
var logger = logger2.GetLogger()
func NewSmsManager(sysConfig *types.SystemConfig, aliyun *AliYunSmsService, bao *BaoSmsService) (*SmsManager, error) {
return &SmsManager{
active: sysConfig.SMS.Active,
aliyun: aliyun,
bao: bao,
}, nil
}
func (m *SmsManager) GetService() Service {
switch m.active {
case Ali:
return m.aliyun
case Bao:
return m.bao
}
return nil
}
func (m *SmsManager) SetActive(active string) {
m.active = active
}
func (m *SmsManager) UpdateConfig(config types.SMSConfig) {
switch config.Active {
case Ali:
m.aliyun.UpdateConfig(config.Ali)
case Bao:
m.bao.UpdateConfig(config.Bao)
}
m.active = config.Active
}

Some files were not shown because too many files have changed in this diff Show More