mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-26 04:54:28 +08:00
merge v4.2.6
整合 v4.2.6 的后端中间件与服务层重构、前端样式体系迁移和管理端/移动端功能更新,统一清理历史冲突并完成版本升级。 Made-with: Cursor
This commit is contained in:
@@ -8,118 +8,50 @@ package core
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"golang.org/x/image/webp"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AuthConfig 定义授权配置
|
||||
type AuthConfig struct {
|
||||
ExactPaths map[string]bool // 精确匹配的路径
|
||||
PrefixPaths map[string]bool // 前缀匹配的路径
|
||||
}
|
||||
|
||||
var authConfig = &AuthConfig{
|
||||
ExactPaths: map[string]bool{
|
||||
"/api/user/login": false,
|
||||
"/api/user/logout": false,
|
||||
"/api/user/resetPass": false,
|
||||
"/api/user/register": false,
|
||||
"/api/admin/login": false,
|
||||
"/api/admin/logout": false,
|
||||
"/api/admin/login/captcha": false,
|
||||
"/api/app/list": false,
|
||||
"/api/app/type/list": false,
|
||||
"/api/app/list/user": false,
|
||||
"/api/model/list": false,
|
||||
"/api/mj/imgWall": false,
|
||||
"/api/mj/notify": false,
|
||||
"/api/invite/hits": false,
|
||||
"/api/sd/imgWall": false,
|
||||
"/api/dall/imgWall": false,
|
||||
"/api/product/list": false,
|
||||
"/api/menu/list": false,
|
||||
"/api/markMap/client": false,
|
||||
"/api/payment/doPay": false,
|
||||
"/api/payment/payWays": false,
|
||||
"/api/download": false,
|
||||
"/api/dall/models": false,
|
||||
},
|
||||
PrefixPaths: map[string]bool{
|
||||
"/api/test/": false,
|
||||
"/api/payment/notify/": false,
|
||||
"/api/user/clogin": false,
|
||||
"/api/config/": false,
|
||||
"/api/function/": false,
|
||||
"/api/sms/": false,
|
||||
"/api/captcha/": false,
|
||||
"/static/": false,
|
||||
},
|
||||
}
|
||||
|
||||
type AppServer struct {
|
||||
Config *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
SysConfig *types.SystemConfig // system config cache
|
||||
Redis *redis.Client
|
||||
}
|
||||
|
||||
func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
func NewServer(appConfig *types.AppConfig, redis *redis.Client, sysConfig *types.SystemConfig) *AppServer {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
return &AppServer{
|
||||
Config: appConfig,
|
||||
Engine: gin.Default(),
|
||||
Config: appConfig,
|
||||
Redis: redis,
|
||||
Engine: gin.Default(),
|
||||
SysConfig: sysConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AppServer) Init(debug bool, client *redis.Client) {
|
||||
// 允许跨域请求 API
|
||||
s.Engine.Use(corsMiddleware())
|
||||
s.Engine.Use(staticResourceMiddleware())
|
||||
s.Engine.Use(authorizeMiddleware(s, client))
|
||||
s.Engine.Use(parameterHandlerMiddleware())
|
||||
func (s *AppServer) Init(client *redis.Client) {
|
||||
s.Engine.Use(middleware.ParameterHandlerMiddleware())
|
||||
s.Engine.Use(errorHandler)
|
||||
// 添加静态资源访问
|
||||
s.Engine.Static("/static", s.Config.StaticDir)
|
||||
s.Engine.Use(middleware.StaticMiddleware())
|
||||
}
|
||||
|
||||
func (s *AppServer) Run(db *gorm.DB) error {
|
||||
|
||||
// 重命名 config 表字段
|
||||
if db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
||||
db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.Config{}, "marker") {
|
||||
db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
||||
}
|
||||
if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
||||
db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
||||
}
|
||||
if db.Migrator().HasIndex(&model.Config{}, "marker") {
|
||||
db.Migrator().DropIndex(&model.Config{}, "marker")
|
||||
}
|
||||
|
||||
// load system configs
|
||||
var sysConfig model.Config
|
||||
err := db.Where("name", "system").First(&sysConfig).Error
|
||||
@@ -131,94 +63,22 @@ func (s *AppServer) Run(db *gorm.DB) error {
|
||||
return fmt.Errorf("failed to decode system config: %v", err)
|
||||
}
|
||||
|
||||
// 迁移数据表
|
||||
logger.Info("Migrating database tables...")
|
||||
db.AutoMigrate(
|
||||
&model.ChatItem{},
|
||||
&model.ChatMessage{},
|
||||
&model.ChatRole{},
|
||||
&model.ChatModel{},
|
||||
&model.InviteCode{},
|
||||
&model.InviteLog{},
|
||||
&model.Menu{},
|
||||
&model.Order{},
|
||||
&model.Product{},
|
||||
&model.User{},
|
||||
&model.Function{},
|
||||
&model.File{},
|
||||
&model.Redeem{},
|
||||
&model.Config{},
|
||||
&model.ApiKey{},
|
||||
&model.AdminUser{},
|
||||
&model.AppType{},
|
||||
&model.SdJob{},
|
||||
&model.SunoJob{},
|
||||
&model.PowerLog{},
|
||||
&model.VideoJob{},
|
||||
&model.MidJourneyJob{},
|
||||
&model.UserLoginLog{},
|
||||
&model.DallJob{},
|
||||
&model.JimengJob{},
|
||||
)
|
||||
// 手动删除字段
|
||||
if db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
||||
db.Migrator().DropColumn(&model.Order{}, "deleted_at")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
|
||||
db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
|
||||
db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.User{}, "chat_config") {
|
||||
db.Migrator().DropColumn(&model.User{}, "chat_config")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatModel{}, "category") {
|
||||
db.Migrator().DropColumn(&model.ChatModel{}, "category")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.ChatModel{}, "description") {
|
||||
db.Migrator().DropColumn(&model.ChatModel{}, "description")
|
||||
}
|
||||
|
||||
logger.Info("Database tables migrated successfully")
|
||||
|
||||
// 统计安装信息
|
||||
go func() {
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
|
||||
apiURL := fmt.Sprintf("%s/api/installs/push", types.GeekAPIURL)
|
||||
timestamp := time.Now().Unix()
|
||||
product := "geekai-plus"
|
||||
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
|
||||
sign := utils.Sha256(signStr)
|
||||
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
|
||||
if err != nil {
|
||||
logger.Errorf("register install info failed: %v", err)
|
||||
} else {
|
||||
if err == nil {
|
||||
logger.Debugf("register install info success: %v", resp.String())
|
||||
}
|
||||
}
|
||||
}()
|
||||
logger.Infof("http://%s", s.Config.Listen)
|
||||
|
||||
// 统计安装信息
|
||||
go func() {
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
apiURL := fmt.Sprintf("%s/%s", s.Config.ApiConfig.ApiURL, "api/installs/push")
|
||||
timestamp := time.Now().Unix()
|
||||
product := "geekai-plus"
|
||||
signStr := fmt.Sprintf("%s#%s#%d", product, info.HostID, timestamp)
|
||||
sign := utils.Sha256(signStr)
|
||||
resp, err := req.C().R().SetBody(map[string]interface{}{"product": product, "device_id": info.HostID, "timestamp": timestamp, "sign": sign}).Post(apiURL)
|
||||
if err != nil {
|
||||
logger.Errorf("register install info failed: %v", err)
|
||||
} else {
|
||||
logger.Debugf("register install info success: %v", resp.String())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return s.Engine.Run(s.Config.Listen)
|
||||
}
|
||||
|
||||
@@ -235,283 +95,3 @@ func errorHandler(c *gin.Context) {
|
||||
//加载完 defer recover,继续后续接口调用
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// 跨域中间件设置
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// 设置允许的请求源
|
||||
if origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
} else {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
|
||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||
c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
//设置缓存时间
|
||||
c.Header("Access-Control-Max-Age", "172800")
|
||||
//允许客户端传递校验信息比如 cookie (重要)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if method == http.MethodOptions {
|
||||
c.JSON(http.StatusOK, "ok!")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Info("Panic info is: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 用户授权验证
|
||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !needLogin(c) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
var tokenString string
|
||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||
if isAdminApi { // 后台管理 API
|
||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||
} else if clientProtocols != "" { // Websocket 连接
|
||||
// 解析子协议内容
|
||||
protocols := strings.Split(clientProtocols, ",")
|
||||
if protocols[0] == "realtime" {
|
||||
tokenString = strings.TrimSpace(protocols[1][25:])
|
||||
} else if protocols[0] == "token" {
|
||||
tokenString = strings.TrimSpace(protocols[1])
|
||||
}
|
||||
} else {
|
||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||
}
|
||||
|
||||
if tokenString == "" {
|
||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
if isAdminApi {
|
||||
return []byte(s.Config.AdminSession.SecretKey), nil
|
||||
} else {
|
||||
return []byte(s.Config.Session.SecretKey), nil
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
resp.NotAuth(c, "Token is invalid")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||
resp.NotAuth(c, "Token is expired")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("users/%v", claims["user_id"])
|
||||
if isAdminApi {
|
||||
key = fmt.Sprintf("admin/%v", claims["user_id"])
|
||||
}
|
||||
if _, err := client.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c, "Token is not found in redis")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(types.LoginUserID, claims["user_id"])
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func needLogin(c *gin.Context) bool {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 如果不是 API 路径,不需要登录
|
||||
if !strings.HasPrefix(path, "/api") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查精确匹配的路径
|
||||
if skip, exists := authConfig.ExactPaths[path]; exists {
|
||||
return skip
|
||||
}
|
||||
|
||||
// 检查前缀匹配的路径
|
||||
for prefix, skip := range authConfig.PrefixPaths {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return skip
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 跳过授权
|
||||
func (s *AppServer) SkipAuth(url string, prefix bool) {
|
||||
if prefix {
|
||||
authConfig.PrefixPaths[url] = false
|
||||
} else {
|
||||
authConfig.ExactPaths[url] = false
|
||||
}
|
||||
}
|
||||
|
||||
// 统一参数处理
|
||||
func parameterHandlerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// GET 参数处理
|
||||
params := c.Request.URL.Query()
|
||||
for key, values := range params {
|
||||
for i, value := range values {
|
||||
params[key][i] = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
// update get parameters
|
||||
c.Request.URL.RawQuery = params.Encode()
|
||||
// skip file upload requests
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// process POST JSON request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 还原请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// 将请求体解析为 JSON
|
||||
var jsonData map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&jsonData); err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 对 JSON 数据中的字符串值去除两端空格
|
||||
trimJSONStrings(jsonData)
|
||||
// 更新请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 递归对 JSON 数据中的字符串值去除两端空格
|
||||
func trimJSONStrings(data interface{}) {
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
for key, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[key] = strings.TrimSpace(valueType)
|
||||
case map[string]interface{}, []interface{}:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
for i, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[i] = strings.TrimSpace(valueType)
|
||||
case map[string]interface{}, []interface{}:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 静态资源中间件
|
||||
func staticResourceMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
url := c.Request.URL.String()
|
||||
// 拦截生成缩略图请求
|
||||
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
|
||||
r := strings.SplitAfter(url, "imageView2")
|
||||
size := strings.Split(r[1], "/")
|
||||
if len(size) != 8 {
|
||||
c.String(http.StatusNotFound, "invalid thumb args")
|
||||
return
|
||||
}
|
||||
with := utils.IntValue(size[3], 0)
|
||||
height := utils.IntValue(size[5], 0)
|
||||
quality := utils.IntValue(size[7], 75)
|
||||
|
||||
// 打开图片文件
|
||||
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
c.String(http.StatusNotFound, "Image not found")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 解码图片
|
||||
img, _, err := image.Decode(file)
|
||||
// for .webp image
|
||||
if err != nil {
|
||||
img, err = webp.Decode(file)
|
||||
}
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "Error decoding image")
|
||||
return
|
||||
}
|
||||
|
||||
var newImg image.Image
|
||||
if height == 0 || with == 0 {
|
||||
// 固定宽度,高度自适应
|
||||
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
|
||||
} else {
|
||||
// 生成缩略图
|
||||
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 设置图片缓存有效期为一年 (365天)
|
||||
c.Header("Cache-Control", "max-age=31536000, public")
|
||||
// 直接输出图像数据流
|
||||
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
||||
c.Abort() // 中断请求
|
||||
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
"bytes"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
@@ -30,7 +32,6 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
SecretKey: utils.RandString(64),
|
||||
MaxAge: 86400,
|
||||
},
|
||||
ApiConfig: types.ApiConfig{},
|
||||
OSS: types.OSSConfig{
|
||||
Active: "local",
|
||||
Local: types.LocalStorageConfig{
|
||||
@@ -38,7 +39,6 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
BasePath: "./static/upload",
|
||||
},
|
||||
},
|
||||
AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,3 +74,108 @@ func SaveConfig(config *types.AppConfig) error {
|
||||
|
||||
return os.WriteFile(config.Path, buf.Bytes(), 0644)
|
||||
}
|
||||
|
||||
func LoadSystemConfig(db *gorm.DB) *types.SystemConfig {
|
||||
// 加载系统配置
|
||||
var sysConfig model.Config
|
||||
var baseConfig types.BaseConfig
|
||||
db.Where("name", "system").First(&sysConfig)
|
||||
err := utils.JsonDecode(sysConfig.Value, &baseConfig)
|
||||
if err != nil {
|
||||
logger.Error("load system config error: ", err)
|
||||
}
|
||||
|
||||
// 加载许可证配置
|
||||
var license types.License
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyLicense).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &license)
|
||||
if err != nil {
|
||||
logger.Error("load license config error: ", err)
|
||||
}
|
||||
|
||||
// 加载验证码配置
|
||||
var captchaConfig types.CaptchaConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyCaptcha).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &captchaConfig)
|
||||
if err != nil {
|
||||
logger.Error("load geek service config error: ", err)
|
||||
}
|
||||
|
||||
// 加载微信登录配置
|
||||
var wxLoginConfig types.WxLoginConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyWxLogin).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &wxLoginConfig)
|
||||
if err != nil {
|
||||
logger.Error("load wx login config error: ", err)
|
||||
}
|
||||
|
||||
// 加载短信配置
|
||||
var smsConfig types.SMSConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeySms).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &smsConfig)
|
||||
if err != nil {
|
||||
logger.Error("load sms config error: ", err)
|
||||
}
|
||||
|
||||
// 加载 OSS 配置
|
||||
var ossConfig types.OSSConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyOss).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &ossConfig)
|
||||
if err != nil {
|
||||
logger.Error("load oss config error: ", err)
|
||||
}
|
||||
|
||||
// 加载 SMTP 配置
|
||||
var smtpConfig types.SmtpConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeySmtp).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &smtpConfig)
|
||||
if err != nil {
|
||||
logger.Error("load smtp config error: ", err)
|
||||
}
|
||||
|
||||
// 加载支付配置
|
||||
var paymentConfig types.PaymentConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyPayment).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &paymentConfig)
|
||||
if err != nil {
|
||||
logger.Error("load payment config error: ", err)
|
||||
}
|
||||
|
||||
// 加载文本审查配置
|
||||
var moderationConfig types.ModerationConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyModeration).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &moderationConfig)
|
||||
if err != nil {
|
||||
logger.Error("load moderation config error: ", err)
|
||||
}
|
||||
|
||||
// 加载即梦AI配置
|
||||
var jimengConfig types.JimengConfig
|
||||
sysConfig.Id = 0
|
||||
db.Where("name", types.ConfigKeyJimeng).First(&sysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
logger.Error("load jimeng config error: ", err)
|
||||
}
|
||||
|
||||
return &types.SystemConfig{
|
||||
Base: baseConfig,
|
||||
License: license,
|
||||
SMS: smsConfig,
|
||||
OSS: ossConfig,
|
||||
SMTP: smtpConfig,
|
||||
Payment: paymentConfig,
|
||||
Captcha: captchaConfig,
|
||||
WxLogin: wxLoginConfig,
|
||||
Moderation: moderationConfig,
|
||||
Jimeng: jimengConfig,
|
||||
}
|
||||
}
|
||||
|
||||
109
api/core/middleware/auth.go
Normal file
109
api/core/middleware/auth.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// 前端用户授权验证
|
||||
func UserAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenString := c.GetHeader(types.UserAuthHeader)
|
||||
if tokenString == "" {
|
||||
resp.NotAuth(c, "无效的授权令牌")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
resp.NotAuth(c, "令牌无效")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||
resp.NotAuth(c, "令牌过期")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("users/%v", claims["user_id"])
|
||||
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(types.LoginUserID, claims["user_id"])
|
||||
}
|
||||
}
|
||||
|
||||
// 管理后台用户授权验证
|
||||
func AdminAuthMiddleware(secretKey string, redis *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenString := c.GetHeader(types.AdminAuthHeader)
|
||||
if tokenString == "" {
|
||||
resp.NotAuth(c, "无效的授权令牌")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("不支持的令牌签名方法: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
resp.NotAuth(c, fmt.Sprintf("解析授权令牌失败: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
resp.NotAuth(c, "令牌无效")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||
resp.NotAuth(c, "令牌过期")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("admin/%v", claims["user_id"])
|
||||
if _, err := redis.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(types.AdminUserID, claims["user_id"])
|
||||
}
|
||||
}
|
||||
80
api/core/middleware/parameter.go
Normal file
80
api/core/middleware/parameter.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 统一参数处理
|
||||
func ParameterHandlerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// GET 参数处理
|
||||
params := c.Request.URL.Query()
|
||||
for key, values := range params {
|
||||
for i, value := range values {
|
||||
params[key][i] = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
// update get parameters
|
||||
c.Request.URL.RawQuery = params.Encode()
|
||||
// skip file upload requests
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// process POST JSON request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 还原请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// 将请求体解析为 JSON
|
||||
var jsonData map[string]any
|
||||
if err := c.ShouldBindJSON(&jsonData); err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 对 JSON 数据中的字符串值去除两端空格
|
||||
trimJSONStrings(jsonData)
|
||||
// 更新请求体
|
||||
c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 递归对 JSON 数据中的字符串值去除两端空格
|
||||
func trimJSONStrings(data any) {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
for key, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[key] = strings.TrimSpace(valueType)
|
||||
case map[string]any, []any:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for i, value := range v {
|
||||
switch valueType := value.(type) {
|
||||
case string:
|
||||
v[i] = strings.TrimSpace(valueType)
|
||||
case map[string]any, []any:
|
||||
trimJSONStrings(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
43
api/core/middleware/rate_limit.go
Normal file
43
api/core/middleware/rate_limit.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// RateLimitEvery 使用 Redis 做固定间隔限流:在 interval 内仅允许一次请求
|
||||
// Key 优先使用登录用户ID,若没有则退化为 route + IP
|
||||
func RateLimitEvery(redisClient *redis.Client, interval time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
keyID := ""
|
||||
if userID, ok := c.Get(types.LoginUserID); ok {
|
||||
keyID = fmt.Sprintf("user:%s", utils.InterfaceToString(userID))
|
||||
} else {
|
||||
keyID = fmt.Sprintf("ip:%s", c.ClientIP())
|
||||
}
|
||||
|
||||
fullPath := c.FullPath()
|
||||
if fullPath == "" {
|
||||
fullPath = c.Request.URL.Path
|
||||
}
|
||||
key := fmt.Sprintf("rl:%s:%s", fullPath, keyID)
|
||||
|
||||
okSet, err := redisClient.SetNX(context.Background(), key, 1, interval).Result()
|
||||
if err != nil {
|
||||
// Redis 异常时放行,避免误伤可用性
|
||||
return
|
||||
}
|
||||
if !okSet {
|
||||
c.JSON(http.StatusTooManyRequests, types.BizVo{Code: types.Failed, Message: "请求过于频繁,请稍后重试"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
78
api/core/middleware/static.go
Normal file
78
api/core/middleware/static.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"geekai/utils"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/nfnt/resize"
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// 静态资源中间件
|
||||
func StaticMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
url := c.Request.URL.String()
|
||||
// 拦截生成缩略图请求
|
||||
if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
|
||||
r := strings.SplitAfter(url, "imageView2")
|
||||
size := strings.Split(r[1], "/")
|
||||
if len(size) != 8 {
|
||||
c.String(http.StatusNotFound, "invalid thumb args")
|
||||
return
|
||||
}
|
||||
with := utils.IntValue(size[3], 0)
|
||||
height := utils.IntValue(size[5], 0)
|
||||
quality := utils.IntValue(size[7], 75)
|
||||
|
||||
// 打开图片文件
|
||||
filePath := strings.TrimLeft(c.Request.URL.Path, "/")
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
c.String(http.StatusNotFound, "Image not found")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 解码图片
|
||||
img, _, err := image.Decode(file)
|
||||
// for .webp image
|
||||
if err != nil {
|
||||
img, err = webp.Decode(file)
|
||||
}
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "Error decoding image")
|
||||
return
|
||||
}
|
||||
|
||||
var newImg image.Image
|
||||
if height == 0 || with == 0 {
|
||||
// 固定宽度,高度自适应
|
||||
newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
|
||||
} else {
|
||||
// 生成缩略图
|
||||
newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 设置图片缓存有效期为一年 (365天)
|
||||
c.Header("Cache-Control", "max-age=31536000, public")
|
||||
// 直接输出图像数据流
|
||||
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
|
||||
c.Abort() // 中断请求
|
||||
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -17,88 +17,17 @@ type AppConfig struct {
|
||||
Session Session
|
||||
AdminSession Session
|
||||
ProxyURL string
|
||||
MysqlDns string // mysql 连接地址
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
ApiConfig ApiConfig // ChatPlus API authorization configs
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
XXLConfig XXLConfig
|
||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||
HuPiPayConfig HuPiPayConfig // 虎皮椒支付配置
|
||||
GeekPayConfig GeekPayConfig // GEEK 支付配置
|
||||
WechatPayConfig WechatPayConfig // 微信支付渠道配置
|
||||
TikaHost string // TiKa 服务器地址
|
||||
}
|
||||
|
||||
type SmtpConfig struct {
|
||||
UseTls bool // 是否使用 TLS 发送
|
||||
Host string
|
||||
Port int
|
||||
AppName string // 应用名称
|
||||
From string // 发件人邮箱地址
|
||||
Password string // 发件人邮箱密码
|
||||
}
|
||||
|
||||
type ApiConfig struct {
|
||||
ApiURL string
|
||||
AppId string
|
||||
Token string
|
||||
JimengConfig JimengConfig // 即梦AI配置
|
||||
}
|
||||
|
||||
type AlipayConfig struct {
|
||||
Enabled bool // 是否启用该支付通道
|
||||
SandBox bool // 是否沙盒环境
|
||||
AppId string // 应用 ID
|
||||
UserId string // 支付宝用户 ID
|
||||
PrivateKey string // 用户私钥文件路径
|
||||
PublicKey string // 用户公钥文件路径
|
||||
AlipayPublicKey string // 支付宝公钥文件路径
|
||||
RootCert string // Root 秘钥路径
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
}
|
||||
|
||||
type WechatPayConfig struct {
|
||||
Enabled bool // 是否启用该支付通道
|
||||
AppId string // 公众号的APPID,如:wxd678efh567hg6787
|
||||
MchId string // 直连商户的商户号,由微信支付生成并下发
|
||||
SerialNo string // 商户证书的证书序列号
|
||||
PrivateKey string // 用户私钥文件路径
|
||||
ApiV3Key string // API V3 秘钥
|
||||
NotifyURL string // 异步通知地址
|
||||
}
|
||||
|
||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
|
||||
Enabled bool // 是否启用该支付通道
|
||||
AppId string // App ID
|
||||
AppSecret string // app 密钥
|
||||
ApiURL string // 支付网关
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
}
|
||||
|
||||
// GeekPayConfig GEEK支付配置
|
||||
type GeekPayConfig struct {
|
||||
Enabled bool
|
||||
AppId string // 商户 ID
|
||||
PrivateKey string // 私钥
|
||||
ApiURL string // API 网关
|
||||
NotifyURL string // 异步通知地址
|
||||
ReturnURL string // 同步通知地址
|
||||
Methods []string // 支付方式
|
||||
}
|
||||
|
||||
type XXLConfig struct { // XXL 任务调度配置
|
||||
Enabled bool
|
||||
ServerAddr string
|
||||
ExecutorIp string
|
||||
ExecutorPort string
|
||||
AccessToken string
|
||||
RegistryKey string
|
||||
MysqlDns string // mysql 连接地址
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
SMS SMSConfig // send mobile message config
|
||||
OSS OSSConfig // OSS config
|
||||
SmtpConfig SmtpConfig // 邮件发送配置
|
||||
AlipayConfig AlipayConfig // 支付宝支付渠道配置
|
||||
GeekPayConfig EpayConfig // GEEK 支付配置
|
||||
WechatPayConfig WxPayConfig // 微信支付渠道配置
|
||||
TikaHost string // TiKa 服务器地址
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
@@ -128,32 +57,28 @@ func (c RedisConfig) Url() string {
|
||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||
}
|
||||
|
||||
type SystemConfig struct {
|
||||
Title string `json:"title,omitempty"` // 网站标题
|
||||
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
||||
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
||||
Logo string `json:"logo,omitempty"` // 圆形 Logo
|
||||
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
|
||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||
VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
|
||||
type BaseConfig struct {
|
||||
Title string `json:"title,omitempty"` // 网站标题
|
||||
Slogan string `json:"slogan,omitempty"` // 网站 slogan
|
||||
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
|
||||
Logo string `json:"logo,omitempty"` // 圆形 Logo
|
||||
BarLogo string `json:"bar_logo,omitempty"` // 条形 Logo
|
||||
|
||||
RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
|
||||
EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册
|
||||
|
||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间
|
||||
VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明
|
||||
OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间,单位:分钟
|
||||
|
||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
||||
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
|
||||
PromptPower int `json:"prompt_power,omitempty"` // 生成提示词消耗算力
|
||||
|
||||
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
|
||||
|
||||
@@ -163,15 +88,44 @@ type SystemConfig struct {
|
||||
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
|
||||
MjMode string `json:"mj_mode"` // midjourney 默认的API模式,relax, fast, turbo
|
||||
|
||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
DefaultNickname string `json:"default_nickname"` // 默认昵称
|
||||
ICP string `json:"icp"` // ICP 备案号
|
||||
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
|
||||
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
|
||||
Copyright string `json:"copyright"` // 版权信息
|
||||
ICP string `json:"icp"` // ICP 备案号
|
||||
GaBeian string `json:"ga_beian"` // 公安备案号
|
||||
|
||||
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
||||
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
||||
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
|
||||
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB
|
||||
|
||||
}
|
||||
|
||||
type SystemConfig struct {
|
||||
Base BaseConfig
|
||||
Payment PaymentConfig
|
||||
OSS OSSConfig
|
||||
SMS SMSConfig
|
||||
SMTP SmtpConfig
|
||||
Captcha CaptchaConfig
|
||||
WxLogin WxLoginConfig
|
||||
Jimeng JimengConfig
|
||||
License License
|
||||
Moderation ModerationConfig
|
||||
}
|
||||
|
||||
// 配置键名常量
|
||||
const (
|
||||
ConfigKeySystem = "system"
|
||||
ConfigKeyNotice = "notice"
|
||||
ConfigKeyAgreement = "agreement"
|
||||
ConfigKeyPrivacy = "privacy"
|
||||
ConfigKeyMarkMap = "mark_map"
|
||||
ConfigKeyCaptcha = "captcha"
|
||||
ConfigKeyWxLogin = "wx_login"
|
||||
ConfigKeyLicense = "license"
|
||||
ConfigKeySms = "sms"
|
||||
ConfigKeySmtp = "smtp"
|
||||
ConfigKeyOss = "oss"
|
||||
ConfigKeyPayment = "payment"
|
||||
ConfigKeyModeration = "moderation"
|
||||
ConfigKeyAI3D = "ai3d"
|
||||
ConfigKeyJimeng = "jimeng"
|
||||
)
|
||||
|
||||
33
api/core/types/geekai.go
Normal file
33
api/core/types/geekai.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package types
|
||||
|
||||
import "os"
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// GeekAI 增值服务
|
||||
var GeekAPIURL = "https://sapi.geekai.me"
|
||||
|
||||
func init() {
|
||||
if os.Getenv("GEEK_API_URL") != "" {
|
||||
GeekAPIURL = os.Getenv("GEEK_API_URL")
|
||||
}
|
||||
}
|
||||
|
||||
// CaptchaConfig 行为验证码配置
|
||||
type CaptchaConfig struct {
|
||||
ApiKey string `json:"api_key"`
|
||||
Type string `json:"type"` // 验证码类型, 可选值: "dot" 或 "slide"
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// WxLoginConfig 微信登录配置
|
||||
type WxLoginConfig struct {
|
||||
ApiKey string `json:"api_key"`
|
||||
NotifyURL string `json:"notify_url"` // 登录成功回调 URL
|
||||
Enabled bool `json:"enabled"` // 是否启用微信登录
|
||||
}
|
||||
73
api/core/types/moderation.go
Normal file
73
api/core/types/moderation.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// 文本审查
|
||||
type ModerationConfig struct {
|
||||
Enable bool `json:"enable"` // 是否启用文本审查
|
||||
Active string `json:"active"`
|
||||
EnableGuide bool `json:"enable_guide"` // 是否启用模型引导提示词
|
||||
GuidePrompt string `json:"guide_prompt"` // 模型引导提示词
|
||||
Gitee ModerationGiteeConfig `json:"gitee"`
|
||||
Baidu ModerationBaiduConfig `json:"baidu"`
|
||||
Tencent ModerationTencentConfig `json:"tencent"`
|
||||
}
|
||||
|
||||
const (
|
||||
ModerationGitee = "gitee"
|
||||
ModerationBaidu = "baidu"
|
||||
ModerationTencent = "tencent"
|
||||
)
|
||||
|
||||
// GiteeAI 文本审查配置
|
||||
type ModerationGiteeConfig struct {
|
||||
ApiKey string `json:"api_key"`
|
||||
Model string `json:"model"` // 文本审核模型
|
||||
}
|
||||
|
||||
// 百度文本审查配置
|
||||
type ModerationBaiduConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
}
|
||||
|
||||
// 腾讯云文本审查配置
|
||||
type ModerationTencentConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
}
|
||||
|
||||
type ModerationResult struct {
|
||||
Flagged bool `json:"flagged"`
|
||||
Categories map[string]bool `json:"categories"`
|
||||
CategoryScores map[string]float64 `json:"category_scores"`
|
||||
}
|
||||
|
||||
var ModerationCategories = map[string]string{
|
||||
"politic": "内容涉及人物、事件或敏感的政治观点",
|
||||
"porn": "明确的色情内容",
|
||||
"insult": "具有侮辱、攻击性语言、人身攻击或冒犯性表达",
|
||||
"violence": "包含暴力、血腥、攻击行为或煽动暴力的言论",
|
||||
"illegal": "涉及违法活动的内容,如诈骗、赌博等",
|
||||
"terror": "宣扬恐怖主义、极端暴力或煽动恐怖行为的内容",
|
||||
"ad": "垃圾广告或未经许可的推广内容",
|
||||
"spam": "无意义重复内容或诱导性信息",
|
||||
"abuse": "人身攻击、恶意辱骂或侮辱性言论",
|
||||
"polity": "涉及国家政治、领导人或政策的违规讨论内容",
|
||||
}
|
||||
|
||||
// 敏感词来源
|
||||
const (
|
||||
ModerationSourceChat = "chat"
|
||||
ModerationSourceMJ = "mj"
|
||||
ModerationSourceDalle = "dalle"
|
||||
ModerationSourceSD = "sd"
|
||||
ModerationSourceSuno = "suno"
|
||||
ModerationSourceVideo = "video"
|
||||
ModerationSourceJiMeng = "jimeng"
|
||||
)
|
||||
@@ -11,29 +11,25 @@ type OrderStatus int
|
||||
|
||||
const (
|
||||
OrderNotPaid = OrderStatus(0)
|
||||
OrderScanned = OrderStatus(1) // 已扫码
|
||||
OrderPaidSuccess = OrderStatus(2)
|
||||
OrderPaidSuccess = OrderStatus(2) // 已支付
|
||||
OrderPaidFailed = OrderStatus(3) // 已关闭
|
||||
)
|
||||
|
||||
type OrderRemark struct {
|
||||
Days int `json:"days"` // 有效期
|
||||
Power int `json:"power"` // 增加算力点数
|
||||
Name string `json:"name"` // 产品名称
|
||||
Price float64 `json:"price"`
|
||||
Discount float64 `json:"discount"`
|
||||
Days int `json:"days"` // 有效期
|
||||
Power int `json:"power"` // 增加算力点数
|
||||
Name string `json:"name"` // 产品名称
|
||||
Price float64 `json:"price"`
|
||||
}
|
||||
|
||||
var PayMethods = map[string]string{
|
||||
// PayChannel 支付渠道
|
||||
var PayChannel = map[string]string{
|
||||
"alipay": "支付宝商号",
|
||||
"wechat": "微信商号",
|
||||
"hupi": "虎皮椒",
|
||||
"geek": "易支付",
|
||||
"wxpay": "微信商号",
|
||||
"epay": "易支付",
|
||||
}
|
||||
var PayNames = map[string]string{
|
||||
|
||||
var PayWays = map[string]string{
|
||||
"alipay": "支付宝",
|
||||
"wxpay": "微信支付",
|
||||
"qqpay": "QQ钱包",
|
||||
"jdpay": "京东支付",
|
||||
"douyin": "抖音支付",
|
||||
"paypal": "PayPal支付",
|
||||
}
|
||||
|
||||
@@ -8,41 +8,39 @@ package types
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type OSSConfig struct {
|
||||
Active string
|
||||
Local LocalStorageConfig
|
||||
Minio MiniOssConfig
|
||||
QiNiu QiNiuOssConfig
|
||||
AliYun AliYunOssConfig
|
||||
Active string `json:"active"`
|
||||
Local LocalStorageConfig `json:"local"`
|
||||
Minio MiniOssConfig `json:"minio"`
|
||||
QiNiu QiNiuOssConfig `json:"qiniu"`
|
||||
AliYun AliYunOssConfig `json:"aliyun"`
|
||||
}
|
||||
|
||||
type MiniOssConfig struct {
|
||||
Endpoint string
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Bucket string
|
||||
SubDir string
|
||||
UseSSL bool
|
||||
Domain string
|
||||
Endpoint string `json:"endpoint"`
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Bucket string `json:"bucket"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
type QiNiuOssConfig struct {
|
||||
Zone string
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Bucket string
|
||||
SubDir string
|
||||
Domain string
|
||||
Zone string `json:"zone"`
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Bucket string `json:"bucket"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
type AliYunOssConfig struct {
|
||||
Endpoint string
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Bucket string
|
||||
SubDir string
|
||||
Domain string
|
||||
Endpoint string `json:"endpoint"`
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Bucket string `json:"bucket"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
type LocalStorageConfig struct {
|
||||
BasePath string
|
||||
BaseURL string
|
||||
BasePath string `json:"base_path"`
|
||||
BaseURL string `json:"base_url"`
|
||||
}
|
||||
|
||||
60
api/core/types/payment.go
Normal file
60
api/core/types/payment.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package types
|
||||
|
||||
type PaymentConfig struct {
|
||||
Alipay AlipayConfig `json:"alipay"` // 支付宝支付渠道配置
|
||||
Epay EpayConfig `json:"epay"` // 易支付配置
|
||||
WxPay WxPayConfig `json:"wxpay"` // 微信支付渠道配置
|
||||
}
|
||||
|
||||
// AlipayConfig 支付宝支付配置
|
||||
type AlipayConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||
SandBox bool `json:"sandbox"` // 是否沙盒环境
|
||||
AppId string `json:"app_id"` // 应用 ID
|
||||
PrivateKey string `json:"private_key"` // 应用私钥
|
||||
AlipayPublicKey string `json:"alipay_public_key"` // 支付宝公钥
|
||||
Domain string `json:"domain"` // 支付回调域名
|
||||
}
|
||||
|
||||
func (c *AlipayConfig) Equal(other *AlipayConfig) bool {
|
||||
return c.AppId == other.AppId &&
|
||||
c.PrivateKey == other.PrivateKey &&
|
||||
c.AlipayPublicKey == other.AlipayPublicKey &&
|
||||
c.Domain == other.Domain
|
||||
}
|
||||
|
||||
// WxPayConfig 微信支付配置
|
||||
type WxPayConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||
AppId string `json:"app_id"` // 公众号的APPID,如:wxd678efh567hg6787
|
||||
MchId string `json:"mch_id"` // 直连商户的商户号,由微信支付生成并下发
|
||||
SerialNo string `json:"serial_no"` // 商户证书的证书序列号
|
||||
PrivateKey string `json:"private_key"` // 商户证书私钥
|
||||
ApiV3Key string `json:"api_v3_key"` // API V3 秘钥
|
||||
Domain string `json:"domain"` // 支付回调域名
|
||||
}
|
||||
|
||||
func (c *WxPayConfig) Equal(other *WxPayConfig) bool {
|
||||
return c.AppId == other.AppId &&
|
||||
c.MchId == other.MchId &&
|
||||
c.SerialNo == other.SerialNo &&
|
||||
c.PrivateKey == other.PrivateKey &&
|
||||
c.ApiV3Key == other.ApiV3Key &&
|
||||
c.Domain == other.Domain
|
||||
}
|
||||
|
||||
// EpayConfig 易支付配置
|
||||
type EpayConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用该支付通道
|
||||
AppId string `json:"app_id"` // 商户 ID
|
||||
PrivateKey string `json:"private_key"` // 私钥
|
||||
ApiURL string `json:"api_url"` // z支付 API 网关
|
||||
Domain string `json:"domain"` // 支付回调域名
|
||||
}
|
||||
|
||||
func (c *EpayConfig) Equal(other *EpayConfig) bool {
|
||||
return c.AppId == other.AppId &&
|
||||
c.PrivateKey == other.PrivateKey &&
|
||||
c.ApiURL == other.ApiURL &&
|
||||
c.Domain == other.Domain
|
||||
}
|
||||
@@ -8,6 +8,7 @@ package types
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
const LoginUserID = "LOGIN_USER_ID"
|
||||
const AdminUserID = "ADMIN_USER_ID"
|
||||
const LoginUserCache = "LOGIN_USER_CACHE"
|
||||
|
||||
const UserAuthHeader = "Authorization"
|
||||
|
||||
@@ -8,26 +8,23 @@ package types
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type SMSConfig struct {
|
||||
Active string
|
||||
Ali SmsConfigAli
|
||||
Bao SmsConfigBao
|
||||
Active string `json:"active"`
|
||||
Ali SmsConfigAli `json:"aliyun"`
|
||||
Bao SmsConfigBao `json:"bao"`
|
||||
}
|
||||
|
||||
// SmsConfigAli 阿里云短信平台配置
|
||||
type SmsConfigAli struct {
|
||||
AccessKey string
|
||||
AccessSecret string
|
||||
Product string
|
||||
Domain string
|
||||
Sign string // 短信签名
|
||||
CodeTempId string // 验证码短信模板 ID
|
||||
AccessKey string `json:"access_key"`
|
||||
AccessSecret string `json:"access_secret"`
|
||||
Sign string `json:"sign"` // 短信签名
|
||||
CodeTempId string `json:"code_temp_id"` // 验证码短信模板 ID
|
||||
}
|
||||
|
||||
// SmsConfigBao 短信宝平台配置
|
||||
type SmsConfigBao struct {
|
||||
Username string //短信宝平台注册的用户名
|
||||
Password string //短信宝平台注册的密码
|
||||
Domain string //域名
|
||||
Sign string // 短信签名
|
||||
CodeTemplate string // 验证码短信模板 匹配
|
||||
Username string `json:"username"` //短信宝平台注册的用户名
|
||||
Password string `json:"password"` //短信宝平台注册的密码
|
||||
Sign string `json:"sign"` // 短信签名
|
||||
CodeTemplate string `json:"code_template"` // 验证码短信模板 匹配
|
||||
}
|
||||
|
||||
26
api/core/types/smtp.go
Normal file
26
api/core/types/smtp.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package types
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
type SmtpConfig struct {
|
||||
UseTls bool `json:"use_tls"` // 是否使用 TLS 发送
|
||||
Host string `json:"host"` // 邮件服务器地址
|
||||
Port int `json:"port"` // 邮件服务器端口
|
||||
AppName string `json:"app_name"` // 应用名称
|
||||
From string `json:"from"` // 发件人邮箱地址
|
||||
Password string `json:"password"` // 发件人邮箱密码
|
||||
}
|
||||
|
||||
func (s *SmtpConfig) Equal(other *SmtpConfig) bool {
|
||||
return s.UseTls == other.UseTls &&
|
||||
s.Host == other.Host &&
|
||||
s.Port == other.Port &&
|
||||
s.AppName == other.AppName &&
|
||||
s.From == other.From &&
|
||||
s.Password == other.Password
|
||||
}
|
||||
@@ -70,17 +70,18 @@ type SdTaskParams struct {
|
||||
|
||||
// DallTask DALL-E task
|
||||
type DallTask struct {
|
||||
ModelId uint `json:"model_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
Power int `json:"power"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
ModelId uint `json:"model_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
Image []string `json:"image,omitempty"`
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Style string `json:"style"`
|
||||
Power int `json:"power"`
|
||||
TranslateModelId int `json:"translate_model_id"` // 提示词翻译模型ID
|
||||
}
|
||||
|
||||
type SunoTask struct {
|
||||
|
||||
@@ -4,7 +4,7 @@ build_name: runner-build
|
||||
build_log: runner-build-errors.log
|
||||
valid_ext: .go, .tpl, .tmpl, .html
|
||||
no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue
|
||||
ignored: assets, tmp, web, .git, .idea, test, data
|
||||
ignored: assets, tmp, web, .git, .idea, test, data, static
|
||||
build_delay: 600
|
||||
colors: 1
|
||||
log_color_main: cyan
|
||||
|
||||
10
api/go.mod
10
api/go.mod
@@ -24,11 +24,9 @@ require (
|
||||
gorm.io/driver/mysql v1.4.7
|
||||
)
|
||||
|
||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
|
||||
|
||||
require (
|
||||
github.com/go-pay/gopay v1.5.101
|
||||
github.com/go-rod/rod v0.116.2
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/go-tika v0.3.1
|
||||
github.com/microcosm-cc/bluemonday v1.0.26
|
||||
github.com/sashabaranov/go-openai v1.38.1
|
||||
@@ -50,11 +48,6 @@ require (
|
||||
github.com/gorilla/css v1.0.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||
github.com/ysmood/fetchup v0.3.0 // indirect
|
||||
github.com/ysmood/goob v0.4.0 // indirect
|
||||
github.com/ysmood/got v0.40.0 // indirect
|
||||
github.com/ysmood/gson v0.7.3 // indirect
|
||||
github.com/ysmood/leakless v0.9.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
)
|
||||
@@ -69,7 +62,6 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gaukas/godicttls v0.0.3 // indirect
|
||||
github.com/go-basic/ipv4 v1.0.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
|
||||
22
api/go.sum
22
api/go.sum
@@ -46,8 +46,6 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
|
||||
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
@@ -80,8 +78,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA=
|
||||
github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
@@ -89,6 +85,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
@@ -261,22 +259,6 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
||||
github.com/ysmood/fetchup v0.3.0 h1:UhYz9xnLEVn2ukSuK3KCgcznWpHMdrmbsPpllcylyu8=
|
||||
github.com/ysmood/fetchup v0.3.0/go.mod h1:hbysoq65PXL0NQeNzUczNYIKpwpkwFL4LXMDEvIQq9A=
|
||||
github.com/ysmood/goob v0.4.0 h1:HsxXhyLBeGzWXnqVKtmT9qM7EuVs/XOgkX7T6r1o1AQ=
|
||||
github.com/ysmood/goob v0.4.0/go.mod h1:u6yx7ZhS4Exf2MwciFr6nIM8knHQIE22lFpWHnfql18=
|
||||
github.com/ysmood/gop v0.2.0 h1:+tFrG0TWPxT6p9ZaZs+VY+opCvHU8/3Fk6BaNv6kqKg=
|
||||
github.com/ysmood/gop v0.2.0/go.mod h1:rr5z2z27oGEbyB787hpEcx4ab8cCiPnKxn0SUHt6xzk=
|
||||
github.com/ysmood/got v0.40.0 h1:ZQk1B55zIvS7zflRrkGfPDrPG3d7+JOza1ZkNxcc74Q=
|
||||
github.com/ysmood/got v0.40.0/go.mod h1:W7DdpuX6skL3NszLmAsC5hT7JAhuLZhByVzHTq874Qg=
|
||||
github.com/ysmood/gotrace v0.6.0 h1:SyI1d4jclswLhg7SWTL6os3L1WOKeNn/ZtzVQF8QmdY=
|
||||
github.com/ysmood/gotrace v0.6.0/go.mod h1:TzhIG7nHDry5//eYZDYcTzuJLYQIkykJzCRIo4/dzQM=
|
||||
github.com/ysmood/gson v0.7.3 h1:QFkWbTH8MxyUTKPkVWAENJhxqdBa4lYTQWqZCiLG6kE=
|
||||
github.com/ysmood/gson v0.7.3/go.mod h1:3Kzs5zDl21g5F/BlLTNcuAGAYLKt2lV5G8D1zF3RNmg=
|
||||
github.com/ysmood/leakless v0.9.0 h1:qxCG5VirSBvmi3uynXFkcnLMzkphdh3xx5FtrORwDCU=
|
||||
github.com/ysmood/leakless v0.9.0/go.mod h1:R8iAXPRaG97QJwqxs74RdwzcRHT1SWCGTNqY8q0JvMQ=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
logger2 "geekai/logger"
|
||||
@@ -19,9 +20,10 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -45,6 +47,26 @@ func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, cap
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ManagerHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("session", h.Session)
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("enable", h.Enable)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}
|
||||
}
|
||||
|
||||
// Login 登录
|
||||
func (h *ManagerHandler) Login(c *gin.Context) {
|
||||
var data struct {
|
||||
@@ -59,19 +81,6 @@ func (h *ManagerHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.EnabledVerify {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captcha.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var manager model.AdminUser
|
||||
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
|
||||
if res.Error != nil {
|
||||
@@ -135,16 +144,15 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
|
||||
|
||||
// Session 会话检测
|
||||
func (h *ManagerHandler) Session(c *gin.Context) {
|
||||
id := h.GetLoginUserId(c)
|
||||
key := fmt.Sprintf("admin/%d", id)
|
||||
if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c)
|
||||
id := h.GetAdminId(c)
|
||||
if id == 0 {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
return
|
||||
}
|
||||
var manager model.AdminUser
|
||||
res := h.DB.Where("id", id).First(&manager)
|
||||
if res.Error != nil {
|
||||
resp.NotAuth(c)
|
||||
err := h.DB.Where("id", id).First(&manager).Error
|
||||
if err != nil {
|
||||
resp.NotAuth(c, "当前用户已退出登录")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -30,6 +31,20 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
||||
return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ApiKeyHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/apikey/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
|
||||
@@ -10,6 +10,7 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -30,14 +31,29 @@ func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
||||
return &ChatAppHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/role/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("sort", h.Sort)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
// Save 创建或者更新某个角色
|
||||
func (h *ChatAppHandler) Save(c *gin.Context) {
|
||||
var data vo.ChatRole
|
||||
var data vo.ChatApp
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
var role model.ChatRole
|
||||
var role model.ChatApp
|
||||
err := utils.CopyObject(data, &role)
|
||||
if err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -65,8 +81,8 @@ func (h *ChatAppHandler) Save(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
var items []model.ChatRole
|
||||
var roles = make([]vo.ChatRole, 0)
|
||||
var items []model.ChatApp
|
||||
var roles = make([]vo.ChatApp, 0)
|
||||
res := h.DB.Order("sort_num ASC").Find(&items)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "No data found")
|
||||
@@ -107,7 +123,7 @@ func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
var role vo.ChatRole
|
||||
var role vo.ChatApp
|
||||
err := utils.CopyObject(v, &role)
|
||||
if err == nil {
|
||||
role.Id = v.Id
|
||||
@@ -135,7 +151,7 @@ func (h *ChatAppHandler) Sort(c *gin.Context) {
|
||||
}
|
||||
|
||||
for index, id := range data.Ids {
|
||||
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
||||
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -157,7 +173,7 @@ func (h *ChatAppHandler) Set(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||
err := h.DB.Model(&model.ChatApp{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -172,9 +188,8 @@ func (h *ChatAppHandler) Remove(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
|
||||
res := h.DB.Where("id", id).Delete(&model.ChatApp{})
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database:", res.Error)
|
||||
resp.ERROR(c, "删除失败!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -20,6 +22,21 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
|
||||
return &ChatAppTypeHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppTypeHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/app/type/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
}
|
||||
}
|
||||
|
||||
// Save 创建或更新App类型
|
||||
func (h *ChatAppTypeHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
|
||||
@@ -9,6 +9,7 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -28,16 +29,31 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/chat/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list", h.List)
|
||||
group.POST("message", h.Messages)
|
||||
group.GET("history", h.History)
|
||||
group.GET("remove", h.RemoveChat)
|
||||
group.GET("message/remove", h.RemoveMessage)
|
||||
}
|
||||
}
|
||||
|
||||
type chatItemVo struct {
|
||||
Username string `json:"username"`
|
||||
UserId uint `json:"user_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Title string `json:"title"`
|
||||
Role vo.ChatRole `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Token int `json:"token"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
MsgNum int `json:"msg_num"` // 消息数量
|
||||
Username string `json:"username"`
|
||||
UserId uint `json:"user_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Title string `json:"title"`
|
||||
Role vo.ChatApp `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Token int `json:"token"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
MsgNum int `json:"msg_num"` // 消息数量
|
||||
}
|
||||
|
||||
func (h *ChatHandler) List(c *gin.Context) {
|
||||
@@ -87,7 +103,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
}
|
||||
var messages []model.ChatMessage
|
||||
var users []model.User
|
||||
var roles []model.ChatRole
|
||||
var roles []model.ChatApp
|
||||
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
|
||||
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||
@@ -95,7 +111,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
tokenMap := make(map[string]int)
|
||||
userMap := make(map[uint]string)
|
||||
msgMap := make(map[string]int)
|
||||
roleMap := make(map[uint]vo.ChatRole)
|
||||
roleMap := make(map[uint]vo.ChatApp)
|
||||
for _, msg := range messages {
|
||||
tokenMap[msg.ChatId] += msg.Tokens
|
||||
msgMap[msg.ChatId] += 1
|
||||
@@ -104,7 +120,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
userMap[user.Id] = user.Username
|
||||
}
|
||||
for _, r := range roles {
|
||||
var roleVo vo.ChatRole
|
||||
var roleVo vo.ChatApp
|
||||
err := utils.CopyObject(r, &roleVo)
|
||||
if err != nil {
|
||||
continue
|
||||
|
||||
@@ -8,7 +8,9 @@ package admin
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -28,6 +30,22 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||
return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatModelHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/model/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("batch-remove", h.BatchRemove)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ChatModelHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
@@ -201,3 +219,33 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除模型
|
||||
func (h *ChatModelHandler) BatchRemove(c *gin.Context) {
|
||||
var data struct {
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if len(data.Ids) == 0 {
|
||||
resp.ERROR(c, "请选择要删除的模型")
|
||||
return
|
||||
}
|
||||
|
||||
// 执行批量删除
|
||||
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.ChatModel{}).Error
|
||||
if err != nil {
|
||||
logger.Error("批量删除模型失败:", err)
|
||||
resp.ERROR(c, "批量删除失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"message": fmt.Sprintf("成功删除 %d 个模型", len(data.Ids)),
|
||||
"deleted_count": len(data.Ids),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,106 +9,399 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/store"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/payment"
|
||||
"geekai/service/sms"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConfigHandler struct {
|
||||
handler.BaseHandler
|
||||
levelDB *store.LevelDB
|
||||
licenseService *service.LicenseService
|
||||
licenseService *service.LicenseService
|
||||
sysConfig *types.SystemConfig
|
||||
alipayService *payment.AlipayService
|
||||
wxpayService *payment.WxPayService
|
||||
epayService *payment.EPayService
|
||||
smsManager *sms.SmsManager
|
||||
uploaderManager *oss.UploaderManager
|
||||
smtpService *service.SmtpService
|
||||
captchaService *service.CaptchaService
|
||||
wxLoginService *service.WxLoginService
|
||||
}
|
||||
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
|
||||
func NewConfigHandler(
|
||||
app *core.AppServer,
|
||||
db *gorm.DB,
|
||||
licenseService *service.LicenseService,
|
||||
sysConfig *types.SystemConfig,
|
||||
alipayService *payment.AlipayService,
|
||||
wxpayService *payment.WxPayService,
|
||||
epayService *payment.EPayService,
|
||||
smsManager *sms.SmsManager,
|
||||
uploaderManager *oss.UploaderManager,
|
||||
smtpService *service.SmtpService,
|
||||
captchaService *service.CaptchaService,
|
||||
wxLoginService *service.WxLoginService,
|
||||
) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
levelDB: levelDB,
|
||||
licenseService: licenseService,
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
licenseService: licenseService,
|
||||
sysConfig: sysConfig,
|
||||
alipayService: alipayService,
|
||||
wxpayService: wxpayService,
|
||||
epayService: epayService,
|
||||
smsManager: smsManager,
|
||||
uploaderManager: uploaderManager,
|
||||
smtpService: smtpService,
|
||||
captchaService: captchaService,
|
||||
wxLoginService: wxLoginService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
var data struct {
|
||||
Key string `json:"key"`
|
||||
Config struct {
|
||||
types.SystemConfig
|
||||
Content string `json:"content,omitempty"`
|
||||
Updated bool `json:"updated,omitempty"`
|
||||
} `json:"config"`
|
||||
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ConfigHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/admin/config")
|
||||
|
||||
// 需要管理员登录的接口
|
||||
rg.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
rg.POST("update/base", h.UpdateBase)
|
||||
rg.POST("update/power", h.UpdatePower)
|
||||
rg.POST("update/notice", h.UpdateNotice)
|
||||
rg.POST("update/agreement", h.UpdateAgreement)
|
||||
rg.POST("update/privacy", h.UpdatePrivacy)
|
||||
rg.POST("update/mark_map", h.UpdateMarkMap)
|
||||
rg.POST("update/captcha", h.UpdateCaptcha)
|
||||
rg.POST("update/wx_login", h.UpdateWxLogin)
|
||||
rg.POST("update/payment", h.UpdatePayment)
|
||||
rg.POST("update/sms", h.UpdateSms)
|
||||
rg.POST("update/oss", h.UpdateOss)
|
||||
rg.POST("update/smtp", h.UpdateStmp)
|
||||
rg.GET("get", h.Get)
|
||||
rg.POST("license/active", h.Active)
|
||||
rg.GET("license/get", h.GetLicense)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateBase 更新基础配置
|
||||
func (h *ConfigHandler) UpdateBase(c *gin.Context) {
|
||||
var data types.BaseConfig
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
logger.Errorf("Update config failed: %v", err)
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// ONLY authorized user can change the copyright
|
||||
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
|
||||
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
|
||||
// 未授权的话不允许修改版权
|
||||
license := h.licenseService.GetLicense()
|
||||
if !license.IsActive && data.Copyright != h.sysConfig.Base.Copyright {
|
||||
resp.ERROR(c, "未授权系统不允许修改版权信息")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果要启用图形验证码功能,则检查是否配置了 API 服务
|
||||
if data.Config.EnabledVerify && h.App.Config.ApiConfig.AppId == "" {
|
||||
resp.ERROR(c, "启用验证码服务需要先配置 GeekAI 官方 API 服务 AppId 和 Token")
|
||||
// 未授权的话不允许修改 Logo
|
||||
if !license.IsActive && data.Logo != h.sysConfig.Base.Logo {
|
||||
resp.ERROR(c, "未授权系统不允许修改 Logo")
|
||||
return
|
||||
}
|
||||
|
||||
value := utils.JsonEncode(&data.Config)
|
||||
config := model.Config{Name: data.Key, Value: value}
|
||||
res := h.DB.FirstOrCreate(&config, model.Config{Name: data.Key})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
err := h.Update(types.ConfigKeySystem, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if config.Id > 0 {
|
||||
config.Value = value
|
||||
res := h.DB.Updates(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
h.sysConfig.Base = data
|
||||
|
||||
// update config cache for AppServer
|
||||
var cfg model.Config
|
||||
h.DB.Where("name", data.Key).First(&cfg)
|
||||
var err error
|
||||
if data.Key == "system" {
|
||||
err = utils.JsonDecode(cfg.Value, &h.App.SysConfig)
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.Infof("Update AppServer's config successfully: %v", config.Value)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, config)
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// Get 获取指定的系统配置
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
// UpdatePower 更新系统配置
|
||||
func (h *ConfigHandler) UpdatePower(c *gin.Context) {
|
||||
var data struct {
|
||||
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
|
||||
DailyPower int `json:"daily_power,omitempty"` // 每日签到赠送算力
|
||||
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
|
||||
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
|
||||
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
|
||||
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
|
||||
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
|
||||
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
|
||||
KeLingPowers map[string]int `json:"keling_powers,omitempty"` // 可灵生成视频消耗算力
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
h.sysConfig.Base.InitPower = data.InitPower
|
||||
h.sysConfig.Base.DailyPower = data.DailyPower
|
||||
h.sysConfig.Base.InvitePower = data.InvitePower
|
||||
h.sysConfig.Base.MjPower = data.MjPower
|
||||
h.sysConfig.Base.MjActionPower = data.MjActionPower
|
||||
h.sysConfig.Base.SdPower = data.SdPower
|
||||
h.sysConfig.Base.SunoPower = data.SunoPower
|
||||
h.sysConfig.Base.LumaPower = data.LumaPower
|
||||
h.sysConfig.Base.KeLingPowers = data.KeLingPowers
|
||||
|
||||
err := h.Update(types.ConfigKeySystem, h.sysConfig.Base)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, h.sysConfig.Base)
|
||||
}
|
||||
|
||||
// UpdateNotice 更新公告配置
|
||||
func (h *ConfigHandler) UpdateNotice(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyNotice, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateAgreement 更新用户协议配置
|
||||
func (h *ConfigHandler) UpdateAgreement(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyAgreement, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdatePrivacy 更新隐私政策配置
|
||||
func (h *ConfigHandler) UpdatePrivacy(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyPrivacy, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateMarkMap 更新思维导图配置
|
||||
func (h *ConfigHandler) UpdateMarkMap(c *gin.Context) {
|
||||
var data struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyMarkMap, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateCaptcha 更新行为验证码配置
|
||||
func (h *ConfigHandler) UpdateCaptcha(c *gin.Context) {
|
||||
var data types.CaptchaConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyCaptcha, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.captchaService.UpdateConfig(data)
|
||||
resp.SUCCESS(c, data)
|
||||
|
||||
}
|
||||
|
||||
// UpdatePayment 更新支付配置
|
||||
func (h *ConfigHandler) UpdatePayment(c *gin.Context) {
|
||||
var data types.PaymentConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyPayment, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 如果启用状态发生改变,则需要更新支付服务配置
|
||||
if data.WxPay.Enabled {
|
||||
err = h.wxpayService.UpdateConfig(&data.WxPay)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if data.Epay.Enabled {
|
||||
h.epayService.UpdateConfig(&data.Epay)
|
||||
}
|
||||
if data.Alipay.Enabled {
|
||||
err = h.alipayService.UpdateConfig(&data.Alipay)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.sysConfig.Payment = data
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateSms 更新短信配置
|
||||
func (h *ConfigHandler) UpdateSms(c *gin.Context) {
|
||||
var data types.SMSConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeySms, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新服务配置
|
||||
h.smsManager.UpdateConfig(data)
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateOss 更新 Oss 配置
|
||||
func (h *ConfigHandler) UpdateOss(c *gin.Context) {
|
||||
var data types.OSSConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeyOss, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新服务配置
|
||||
h.uploaderManager.UpdateConfig(data)
|
||||
h.sysConfig.OSS = data
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateStmp 更新 Stmp 配置
|
||||
func (h *ConfigHandler) UpdateStmp(c *gin.Context) {
|
||||
var data types.SmtpConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.Update(types.ConfigKeySmtp, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新服务配置
|
||||
h.smtpService.UpdateConfig(&data)
|
||||
h.sysConfig.SMTP = data
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// UpdateWxLogin 更新微信登录配置
|
||||
func (h *ConfigHandler) UpdateWxLogin(c *gin.Context) {
|
||||
var data types.WxLoginConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
err := h.Update(types.ConfigKeyWxLogin, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if data.Enabled {
|
||||
h.wxLoginService.UpdateConfig(data)
|
||||
}
|
||||
|
||||
h.sysConfig.WxLogin = data
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// Update 更新系统配置
|
||||
func (h *ConfigHandler) Update(name string, value any) error {
|
||||
var config model.Config
|
||||
res := h.DB.Where("name", key).First(&config)
|
||||
err := h.DB.Where("name", name).First(&config).Error
|
||||
if err != nil { // 不存在则创建
|
||||
config.Name = name
|
||||
config.Value = utils.JsonEncode(value)
|
||||
return h.DB.Create(&config).Error
|
||||
} else { // 存在则更新
|
||||
config.Value = utils.JsonEncode(value)
|
||||
return h.DB.Updates(&config).Error
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Get 获取指定名称的系统配置
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
name := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("name", name).First(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]interface{}
|
||||
var value map[string]any
|
||||
err := utils.JsonDecode(config.Value, &value)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
@@ -127,19 +420,21 @@ func (h *ConfigHandler) Active(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
info, err := host.Info()
|
||||
|
||||
err := h.licenseService.ActiveLicense(data.License)
|
||||
license := h.licenseService.GetLicense()
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = h.licenseService.ActiveLicense(data.License, info.HostID)
|
||||
if err != nil {
|
||||
if err := h.Update(types.ConfigKeyLicense, license); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 更新系统配置
|
||||
h.sysConfig.License = *license
|
||||
|
||||
resp.SUCCESS(c)
|
||||
resp.SUCCESS(c, license.MachineId)
|
||||
|
||||
}
|
||||
|
||||
@@ -148,69 +443,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||
license := h.licenseService.GetLicense()
|
||||
resp.SUCCESS(c, license)
|
||||
}
|
||||
|
||||
// FixData 修复数据
|
||||
func (h *ConfigHandler) FixData(c *gin.Context) {
|
||||
resp.ERROR(c, "当前升级版本没有数据需要修正!")
|
||||
//var fixed bool
|
||||
//version := "data_fix_4.1.4"
|
||||
//err := h.levelDB.Get(version, &fixed)
|
||||
//if err == nil || fixed {
|
||||
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
|
||||
// return
|
||||
//}
|
||||
//tx := h.DB.Begin()
|
||||
//var users []model.User
|
||||
//err = tx.Find(&users).Error
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//for _, user := range users {
|
||||
// if user.Email != "" || user.Mobile != "" {
|
||||
// continue
|
||||
// }
|
||||
// if utils.IsValidEmail(user.Username) {
|
||||
// user.Email = user.Username
|
||||
// } else if utils.IsValidMobile(user.Username) {
|
||||
// user.Mobile = user.Username
|
||||
// }
|
||||
// err = tx.Save(&user).Error
|
||||
// if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// tx.Rollback()
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//var orders []model.Order
|
||||
//err = h.DB.Find(&orders).Error
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//for _, order := range orders {
|
||||
// if order.PayWay == "支付宝" {
|
||||
// order.PayWay = "alipay"
|
||||
// order.PayType = "alipay"
|
||||
// } else if order.PayWay == "微信支付" {
|
||||
// order.PayWay = "wechat"
|
||||
// order.PayType = "wxpay"
|
||||
// } else if order.PayWay == "hupi" {
|
||||
// order.PayType = "wxpay"
|
||||
// }
|
||||
// err = tx.Save(&order).Error
|
||||
// if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// tx.Rollback()
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
//tx.Commit()
|
||||
//err = h.levelDB.Put(version, true)
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -13,10 +13,11 @@ import (
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DashboardHandler struct {
|
||||
@@ -27,46 +28,161 @@ func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
|
||||
return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *DashboardHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/dashboard/")
|
||||
group.GET("stats", h.Stats)
|
||||
}
|
||||
|
||||
// statsVo 增加 recentOrders、recentUsers 字段
|
||||
// 最近订单
|
||||
type OrderBrief struct {
|
||||
OrderNo string `json:"order_no"`
|
||||
Amount float64 `json:"amount"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// 最近用户
|
||||
type UserBrief struct {
|
||||
Nickname string `json:"nickname"`
|
||||
Avatar string `json:"avatar"`
|
||||
LastActive time.Time `json:"last_active"`
|
||||
}
|
||||
|
||||
type statsVo struct {
|
||||
Users int64 `json:"users"`
|
||||
Chats int64 `json:"chats"`
|
||||
Tokens int `json:"tokens"`
|
||||
Income float64 `json:"income"`
|
||||
Chart map[string]map[string]float64 `json:"chart"`
|
||||
Users int64 `json:"users"`
|
||||
Chats int64 `json:"chats"`
|
||||
Tokens int `json:"tokens"`
|
||||
Income float64 `json:"income"`
|
||||
Chart map[string]map[string]float64 `json:"chart"`
|
||||
TodayUsers int64 `json:"todayUsers"`
|
||||
TodayChats int64 `json:"todayChats"`
|
||||
TodayTokens int `json:"todayTokens"`
|
||||
TodayIncome float64 `json:"todayIncome"`
|
||||
TodayOrders int64 `json:"todayOrders"`
|
||||
TodayImageJobs int64 `json:"todayImageJobs"`
|
||||
TodayVideoJobs int64 `json:"todayVideoJobs"`
|
||||
TodayMusicJobs int64 `json:"todayMusicJobs"`
|
||||
Orders int64 `json:"orders"`
|
||||
ImageJobs int64 `json:"imageJobs"`
|
||||
VideoJobs int64 `json:"videoJobs"`
|
||||
MusicJobs int64 `json:"musicJobs"`
|
||||
RecentOrders []OrderBrief `json:"recentOrders"`
|
||||
RecentUsers []UserBrief `json:"recentUsers"`
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||
stats := statsVo{}
|
||||
// new users statistic
|
||||
var userCount int64
|
||||
now := time.Now()
|
||||
zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
|
||||
if res.Error == nil {
|
||||
stats.Users = userCount
|
||||
|
||||
// 总用户数
|
||||
h.DB.Model(&model.User{}).Count(&stats.Users)
|
||||
|
||||
// 今日新增用户
|
||||
h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&stats.TodayUsers)
|
||||
|
||||
// 总对话数
|
||||
h.DB.Model(&model.ChatItem{}).Count(&stats.Chats)
|
||||
|
||||
// 今日新增对话
|
||||
h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&stats.TodayChats)
|
||||
|
||||
// 总算力消耗
|
||||
var powerLogs []model.PowerLog
|
||||
h.DB.Where("mark = ?", types.PowerSub).Find(&powerLogs)
|
||||
for _, item := range powerLogs {
|
||||
stats.Tokens += item.Amount
|
||||
}
|
||||
|
||||
// new chats statistic
|
||||
var chatCount int64
|
||||
res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
|
||||
if res.Error == nil {
|
||||
stats.Chats = chatCount
|
||||
// 今日算力消耗
|
||||
var todayPowerLogs []model.PowerLog
|
||||
h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", zeroTime).Find(&todayPowerLogs)
|
||||
for _, item := range todayPowerLogs {
|
||||
stats.TodayTokens += item.Amount
|
||||
}
|
||||
|
||||
// tokens took stats
|
||||
var historyMessages []model.ChatMessage
|
||||
res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
|
||||
for _, item := range historyMessages {
|
||||
stats.Tokens += item.Tokens
|
||||
}
|
||||
|
||||
// 订单收入
|
||||
var orders []model.Order
|
||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
|
||||
for _, item := range orders {
|
||||
// 总收入
|
||||
var allOrders []model.Order
|
||||
h.DB.Where("status = ?", types.OrderPaidSuccess).Find(&allOrders)
|
||||
for _, item := range allOrders {
|
||||
stats.Income += item.Amount
|
||||
}
|
||||
|
||||
// 今日收入
|
||||
var todayOrders []model.Order
|
||||
h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&todayOrders)
|
||||
for _, item := range todayOrders {
|
||||
stats.TodayIncome += item.Amount
|
||||
}
|
||||
|
||||
// 订单总数
|
||||
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Count(&stats.Orders)
|
||||
|
||||
// 今日订单数
|
||||
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Count(&stats.TodayOrders)
|
||||
|
||||
// 图片生成任务统计
|
||||
var mjJobs, sdJobs, dallJobs, jimengImageJobs int64
|
||||
h.DB.Model(&model.MidJourneyJob{}).Count(&mjJobs)
|
||||
h.DB.Model(&model.SdJob{}).Count(&sdJobs)
|
||||
h.DB.Model(&model.DallJob{}).Count(&dallJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Count(&jimengImageJobs)
|
||||
stats.ImageJobs = mjJobs + sdJobs + dallJobs + jimengImageJobs
|
||||
|
||||
logger.Info("stats.ImageJobs", stats.ImageJobs)
|
||||
|
||||
// 今日图片生成任务统计
|
||||
var todayMjJobs, todaySdJobs, todayDallJobs, todayJimengImageJobs int64
|
||||
h.DB.Model(&model.MidJourneyJob{}).Where("created_at > ?", zeroTime).Count(&todayMjJobs)
|
||||
h.DB.Model(&model.SdJob{}).Where("created_at > ?", zeroTime).Count(&todaySdJobs)
|
||||
h.DB.Model(&model.DallJob{}).Where("created_at > ?", zeroTime).Count(&todayDallJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_image", "image_to_image", "image_edit", "image_effects"}).Where("created_at > ?", zeroTime).Count(&todayJimengImageJobs)
|
||||
stats.TodayImageJobs = todayMjJobs + todaySdJobs + todayDallJobs + todayJimengImageJobs
|
||||
|
||||
// 视频生成任务统计
|
||||
var videoJobs, jimengVideoJobs int64
|
||||
h.DB.Model(&model.VideoJob{}).Count(&videoJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Count(&jimengVideoJobs)
|
||||
stats.VideoJobs = videoJobs + jimengVideoJobs
|
||||
|
||||
// 今日视频生成任务统计
|
||||
var todayVideoJobs, todayJimengVideoJobs int64
|
||||
h.DB.Model(&model.VideoJob{}).Where("created_at > ?", zeroTime).Count(&todayVideoJobs)
|
||||
h.DB.Model(&model.JimengJob{}).Where("type IN ?", []string{"text_to_video", "image_to_video"}).Where("created_at > ?", zeroTime).Count(&todayJimengVideoJobs)
|
||||
stats.TodayVideoJobs = todayVideoJobs + todayJimengVideoJobs
|
||||
|
||||
// 音乐生成任务统计
|
||||
h.DB.Model(&model.SunoJob{}).Count(&stats.MusicJobs)
|
||||
|
||||
// 今日音乐生成任务统计
|
||||
h.DB.Model(&model.SunoJob{}).Where("created_at > ?", zeroTime).Count(&stats.TodayMusicJobs)
|
||||
|
||||
// recentOrders: 最近10条已支付订单
|
||||
var orderList []model.Order
|
||||
h.DB.Model(&model.Order{}).Where("status = ?", types.OrderPaidSuccess).Order("created_at desc").Limit(10).Find(&orderList)
|
||||
for _, o := range orderList {
|
||||
stats.RecentOrders = append(stats.RecentOrders, OrderBrief{
|
||||
OrderNo: o.OrderNo,
|
||||
Amount: o.Amount,
|
||||
CreatedAt: o.CreatedAt,
|
||||
})
|
||||
}
|
||||
// recentUsers: 最近10个注册用户
|
||||
var userList []model.User
|
||||
h.DB.Model(&model.User{}).Order("created_at desc").Limit(10).Find(&userList)
|
||||
for _, u := range userList {
|
||||
lastActive := u.UpdatedAt
|
||||
if lastActive.IsZero() {
|
||||
lastActive = u.CreatedAt
|
||||
}
|
||||
stats.RecentUsers = append(stats.RecentUsers, UserBrief{
|
||||
Nickname: u.Nickname,
|
||||
Avatar: u.Avatar,
|
||||
LastActive: lastActive,
|
||||
})
|
||||
}
|
||||
|
||||
// 统计7天的订单的图表
|
||||
startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
|
||||
var statsChart = make(map[string]map[string]float64)
|
||||
@@ -81,23 +197,29 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||
|
||||
// 统计用户7天增加的曲线
|
||||
var users []model.User
|
||||
res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
|
||||
if res.Error == nil {
|
||||
err := h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users).Error
|
||||
if err == nil {
|
||||
for _, item := range users {
|
||||
userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
|
||||
}
|
||||
}
|
||||
|
||||
// 统计7天Token 消耗
|
||||
res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
|
||||
for _, item := range historyMessages {
|
||||
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
|
||||
// 统计7天算力消耗
|
||||
var chartPowerLogs []model.PowerLog
|
||||
err = h.DB.Where("mark = ?", types.PowerSub).Where("created_at > ?", startDate).Find(&chartPowerLogs).Error
|
||||
if err == nil {
|
||||
for _, item := range chartPowerLogs {
|
||||
historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Amount)
|
||||
}
|
||||
}
|
||||
|
||||
// 统计最近7天的订单
|
||||
res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
|
||||
for _, item := range orders {
|
||||
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||
var orders []model.Order
|
||||
err = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders).Error
|
||||
if err == nil {
|
||||
for _, item := range orders {
|
||||
incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
|
||||
}
|
||||
}
|
||||
|
||||
statsChart["users"] = userStatistic
|
||||
|
||||
@@ -9,6 +9,7 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/store/model"
|
||||
@@ -30,6 +31,21 @@ func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
|
||||
return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *FunctionHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/function/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("token", h.GenToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FunctionHandler) Save(c *gin.Context) {
|
||||
var data vo.Function
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
@@ -119,7 +135,6 @@ func (h *FunctionHandler) GenToken(c *gin.Context) {
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
logger.Error("error with generate token", err)
|
||||
resp.ERROR(c)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
@@ -33,6 +34,20 @@ func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.User
|
||||
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ImageHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/image/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list/mj", h.MjList)
|
||||
group.POST("list/sd", h.SdList)
|
||||
group.POST("list/dall", h.DallList)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
type imageQuery struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Username string `json:"username"`
|
||||
|
||||
@@ -21,18 +21,18 @@ import (
|
||||
// AdminJimengHandler 管理后台即梦AI处理器
|
||||
type AdminJimengHandler struct {
|
||||
handler.BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
jimengClient *jimeng.Client
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengClient *jimeng.Client, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
||||
return &AdminJimengHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
uploader: uploader,
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
jimengClient: jimengClient,
|
||||
userService: userService,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,6 @@ func (h *AdminJimengHandler) RegisterRoutes() {
|
||||
rg.GET("/jobs/:id", h.JobDetail)
|
||||
rg.POST("/jobs/remove", h.BatchRemove)
|
||||
rg.GET("/stats", h.Stats)
|
||||
rg.GET("/config", h.GetConfig)
|
||||
rg.POST("/config/update", h.UpdateConfig)
|
||||
}
|
||||
|
||||
@@ -213,12 +212,6 @@ func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
resp.SUCCESS(c, result)
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
|
||||
jimengConfig := h.jimengService.GetConfig()
|
||||
resp.SUCCESS(c, jimengConfig)
|
||||
}
|
||||
|
||||
// UpdateConfig 更新即梦AI配置
|
||||
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
var req types.JimengConfig
|
||||
@@ -266,31 +259,35 @@ func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
// 保存配置
|
||||
tx := h.DB.Begin()
|
||||
value := utils.JsonEncode(&req)
|
||||
config := model.Config{Name: "jimeng", Value: value}
|
||||
var exist model.Config
|
||||
tx.Where("name", types.ConfigKeyJimeng).First(&exist)
|
||||
|
||||
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "保存配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if config.Id > 0 {
|
||||
config.Value = value
|
||||
err = tx.Updates(&config).Error
|
||||
if exist.Id > 0 {
|
||||
exist.Value = value
|
||||
err := tx.Updates(&exist).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
exist.Name = types.ConfigKeyJimeng
|
||||
exist.Value = value
|
||||
err := tx.Create(&exist).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "创建配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新服务中的客户端配置
|
||||
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey)
|
||||
if updateErr != nil {
|
||||
resp.ERROR(c, updateErr.Error())
|
||||
err := h.jimengClient.UpdateConfig(req)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
h.App.SysConfig.Jimeng = req
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
@@ -33,6 +34,19 @@ func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.User
|
||||
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MediaHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/media/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("suno", h.SunoList)
|
||||
group.POST("videos", h.Videos)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
}
|
||||
|
||||
type mediaQuery struct {
|
||||
Type string `json:"type"` // 任务类型 luma, keling
|
||||
Prompt string `json:"prompt"`
|
||||
|
||||
@@ -27,6 +27,16 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||
return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MenuHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/menu/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
|
||||
func (h *MenuHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
|
||||
333
api/handler/admin/moderation_handler.go
Normal file
333
api/handler/admin/moderation_handler.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service/moderation"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ModerationHandler struct {
|
||||
handler.BaseHandler
|
||||
sysConfig *types.SystemConfig
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewModerationHandler(app *core.AppServer, db *gorm.DB, sysConfig *types.SystemConfig, moderationManager *moderation.ServiceManager) *ModerationHandler {
|
||||
return &ModerationHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, sysConfig: sysConfig, moderationManager: moderationManager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ModerationHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/moderation/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("batch-remove", h.BatchRemove)
|
||||
group.GET("source-list", h.GetSourceList)
|
||||
group.POST("config", h.UpdateModeration)
|
||||
group.POST("test", h.TestModeration)
|
||||
}
|
||||
}
|
||||
|
||||
// List 获取文本审核记录列表
|
||||
func (h *ModerationHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Username string `json:"username"`
|
||||
Source string `json:"source"`
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
|
||||
// 构建查询条件
|
||||
if data.Username != "" {
|
||||
// 通过用户名查找用户ID
|
||||
var user model.User
|
||||
if err := h.DB.Where("username LIKE ?", "%"+data.Username+"%").First(&user).Error; err == nil {
|
||||
session = session.Where("user_id", user.Id)
|
||||
}
|
||||
}
|
||||
|
||||
if data.Source != "" {
|
||||
session = session.Where("source", data.Source)
|
||||
}
|
||||
|
||||
if data.StartDate != "" && data.EndDate != "" {
|
||||
startTime := data.StartDate + " 00:00:00"
|
||||
endTime := data.EndDate + " 23:59:59"
|
||||
session = session.Where("created_at >= ? AND created_at <= ?", startTime, endTime)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
var total int64
|
||||
session.Model(&model.Moderation{}).Count(&total)
|
||||
|
||||
// 分页
|
||||
page := data.Page
|
||||
pageSize := data.PageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
session = session.Offset(offset).Limit(pageSize)
|
||||
|
||||
// 查询数据
|
||||
var items []model.Moderation
|
||||
err := session.Order("id DESC").Find(&items).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
userIds := make([]uint, 0)
|
||||
for _, item := range items {
|
||||
userIds = append(userIds, item.UserId)
|
||||
}
|
||||
|
||||
var users []model.User
|
||||
if len(userIds) > 0 {
|
||||
h.DB.Where("id IN ?", userIds).Find(&users)
|
||||
}
|
||||
|
||||
userMap := make(map[uint]string)
|
||||
for _, user := range users {
|
||||
userMap[user.Id] = user.Username
|
||||
}
|
||||
|
||||
// 转换为响应数据
|
||||
list := make([]map[string]any, 0)
|
||||
for _, item := range items {
|
||||
var moderation types.ModerationResult
|
||||
err := utils.JsonDecode(item.Result, &moderation)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var result []string
|
||||
for value, label := range types.ModerationCategories {
|
||||
if moderation.Categories[value] {
|
||||
result = append(result, label)
|
||||
}
|
||||
}
|
||||
list = append(list, map[string]any{
|
||||
"id": item.Id,
|
||||
"user_id": item.UserId,
|
||||
"username": userMap[item.UserId],
|
||||
"source": item.Source,
|
||||
"input": item.Input,
|
||||
"output": item.Output,
|
||||
"result": result,
|
||||
"created_at": item.CreatedAt.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, map[string]any{
|
||||
"items": list,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ModerationHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
if id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Where("id", id).Delete(&model.Moderation{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除文本审核记录
|
||||
func (h *ModerationHandler) BatchRemove(c *gin.Context) {
|
||||
var data struct {
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if len(data.Ids) == 0 {
|
||||
resp.ERROR(c, "请选择要删除的记录")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.DB.Where("id IN ?", data.Ids).Delete(&model.Moderation{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// 获取 source 列表
|
||||
func (h *ModerationHandler) GetSourceList(c *gin.Context) {
|
||||
sources := []gin.H{
|
||||
{
|
||||
"id": types.ModerationSourceChat,
|
||||
"name": "AI对话",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceMJ,
|
||||
"name": "Midjourney 绘图",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceDalle,
|
||||
"name": "Dalle 绘图",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceSD,
|
||||
"name": "StableDiffusion 绘图",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceSuno,
|
||||
"name": "Suno 音乐",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceVideo,
|
||||
"name": "视频生成",
|
||||
},
|
||||
{
|
||||
"id": types.ModerationSourceJiMeng,
|
||||
"name": "即梦AI",
|
||||
},
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, sources)
|
||||
}
|
||||
|
||||
// UpdateModeration 更新文本审查配置
|
||||
func (h *ModerationHandler) UpdateModeration(c *gin.Context) {
|
||||
var data types.ModerationConfig
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var config model.Config
|
||||
err := h.DB.Where("name", types.ConfigKeyModeration).First(&config).Error
|
||||
if err != nil {
|
||||
config.Name = types.ConfigKeyModeration
|
||||
config.Value = utils.JsonEncode(data)
|
||||
err = h.DB.Create(&config).Error
|
||||
} else {
|
||||
config.Value = utils.JsonEncode(data)
|
||||
err = h.DB.Updates(&config).Error
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.moderationManager.UpdateConfig(data)
|
||||
h.sysConfig.Moderation = data
|
||||
|
||||
resp.SUCCESS(c, data)
|
||||
}
|
||||
|
||||
// 测试结果类型,用于前端显示
|
||||
type ModerationTestResult struct {
|
||||
IsAbnormal bool `json:"isAbnormal"`
|
||||
Details []ModerationTestDetail `json:"details"`
|
||||
}
|
||||
|
||||
type ModerationTestDetail struct {
|
||||
Category string `json:"category"`
|
||||
Description string `json:"description"`
|
||||
Confidence string `json:"confidence"`
|
||||
IsCategory bool `json:"isCategory"`
|
||||
}
|
||||
|
||||
// TestModeration 测试文本审查服务
|
||||
func (h *ModerationHandler) TestModeration(c *gin.Context) {
|
||||
var data struct {
|
||||
Text string `json:"text"`
|
||||
Service string `json:"service"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Text == "" {
|
||||
resp.ERROR(c, "测试文本不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否启用了文本审查
|
||||
if !h.sysConfig.Moderation.Enable {
|
||||
resp.ERROR(c, "文本审查服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前激活的审核服务
|
||||
service := h.moderationManager.GetService()
|
||||
// 执行文本审核
|
||||
result, err := service.Moderate(data.Text)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "审核服务调用失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为前端需要的格式
|
||||
testResult := ModerationTestResult{
|
||||
IsAbnormal: result.Flagged,
|
||||
Details: make([]ModerationTestDetail, 0),
|
||||
}
|
||||
|
||||
// 构建详细信息
|
||||
for category, description := range types.ModerationCategories {
|
||||
score := result.CategoryScores[category]
|
||||
isCategory := result.Categories[category]
|
||||
|
||||
testResult.Details = append(testResult.Details, ModerationTestDetail{
|
||||
Category: category,
|
||||
Description: description,
|
||||
Confidence: fmt.Sprintf("%.2f", score),
|
||||
IsCategory: isCategory,
|
||||
})
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, testResult)
|
||||
}
|
||||
@@ -29,6 +29,14 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||
return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *OrderHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/order/")
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("clear", h.Clear)
|
||||
}
|
||||
|
||||
func (h *OrderHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
OrderNo string `json:"order_no"`
|
||||
@@ -68,16 +76,16 @@ func (h *OrderHandler) List(c *gin.Context) {
|
||||
order.Id = item.Id
|
||||
order.CreatedAt = item.CreatedAt.Unix()
|
||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||
payMethod, ok := types.PayMethods[item.PayWay]
|
||||
payChannel, ok := types.PayChannel[item.Channel]
|
||||
if !ok {
|
||||
payMethod = item.PayWay
|
||||
payChannel = item.Channel
|
||||
}
|
||||
payName, ok := types.PayNames[item.PayType]
|
||||
payWays, ok := types.PayWays[item.PayWay]
|
||||
if !ok {
|
||||
payName = item.PayWay
|
||||
payWays = item.PayWay
|
||||
}
|
||||
order.PayMethod = payMethod
|
||||
order.PayName = payName
|
||||
order.ChannelName = payChannel
|
||||
order.PayName = payWays
|
||||
list = append(list, order)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
@@ -121,8 +129,8 @@ func (h *OrderHandler) Clear(c *gin.Context) {
|
||||
}
|
||||
deleteIds := make([]uint, 0)
|
||||
for _, order := range orders {
|
||||
// 只删除 15 分钟内的未支付订单
|
||||
if time.Now().After(order.CreatedAt.Add(time.Minute * 15)) {
|
||||
// 只删除超时的未支付订单
|
||||
if time.Now().After(order.CreatedAt.Add(time.Minute * time.Duration(h.App.SysConfig.Base.OrderPayTimeout))) {
|
||||
deleteIds = append(deleteIds, order.Id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,12 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||
return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PowerLogHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/powerLog/")
|
||||
group.POST("list", h.List)
|
||||
}
|
||||
|
||||
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Username string `json:"username"`
|
||||
|
||||
@@ -15,9 +15,10 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProductHandler struct {
|
||||
@@ -28,14 +29,22 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||
return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ProductHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/product/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
|
||||
func (h *ProductHandler) Save(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Price float64 `json:"price"`
|
||||
Discount float64 `json:"discount"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Days int `json:"days"`
|
||||
Power int `json:"power"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
@@ -45,12 +54,10 @@ func (h *ProductHandler) Save(c *gin.Context) {
|
||||
}
|
||||
|
||||
item := model.Product{
|
||||
Name: data.Name,
|
||||
Price: data.Price,
|
||||
Discount: data.Discount,
|
||||
Days: data.Days,
|
||||
Power: data.Power,
|
||||
Enabled: data.Enabled}
|
||||
Name: data.Name,
|
||||
Price: data.Price,
|
||||
Power: data.Power,
|
||||
Enabled: data.Enabled}
|
||||
item.Id = data.Id
|
||||
if item.Id > 0 {
|
||||
item.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||
|
||||
@@ -29,6 +29,16 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
|
||||
return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *RedeemHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/redeem/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("create", h.Create)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("export", h.Export)
|
||||
}
|
||||
|
||||
func (h *RedeemHandler) List(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
|
||||
@@ -9,6 +9,7 @@ package admin
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/handler"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
@@ -28,6 +29,17 @@ func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderMan
|
||||
return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *UploadHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/upload")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("", h.Upload)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
// 判断文件大小
|
||||
f, err := c.FormFile("file")
|
||||
@@ -36,7 +48,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.MaxFileSize)*1024*1024 {
|
||||
if h.App.SysConfig.Base.MaxFileSize > 0 && f.Size > int64(h.App.SysConfig.Base.MaxFileSize)*1024*1024 {
|
||||
resp.ERROR(c, "文件大小超过限制")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ package admin
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
@@ -19,10 +20,9 @@ import (
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -36,6 +36,22 @@ func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.Li
|
||||
return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService, redis: redisCli}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *UserHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/admin/user/")
|
||||
|
||||
// 需要管理员授权的接口
|
||||
group.Use(middleware.AdminAuthMiddleware(h.App.Config.AdminSession.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("loginLog", h.LoginLog)
|
||||
group.GET("genLoginLink", h.GenLoginLink)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
}
|
||||
}
|
||||
|
||||
// List 用户列表
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
|
||||
@@ -15,9 +15,10 @@ import (
|
||||
logger2 "geekai/logger"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -69,6 +70,14 @@ func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
|
||||
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||
}
|
||||
|
||||
func (h *BaseHandler) GetAdminId(c *gin.Context) uint {
|
||||
userId, ok := c.Get(types.AdminUserID)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
|
||||
}
|
||||
|
||||
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
|
||||
return h.GetLoginUserId(c) > 0
|
||||
}
|
||||
|
||||
@@ -8,23 +8,45 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 今日头条函数实现
|
||||
|
||||
type CaptchaHandler struct {
|
||||
App *core.AppServer
|
||||
service *service.CaptchaService
|
||||
}
|
||||
|
||||
func NewCaptchaHandler(s *service.CaptchaService) *CaptchaHandler {
|
||||
return &CaptchaHandler{service: s}
|
||||
func NewCaptchaHandler(app *core.AppServer, s *service.CaptchaService, sysConfig *types.SystemConfig) *CaptchaHandler {
|
||||
return &CaptchaHandler{App: app, service: s}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *CaptchaHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/captcha/")
|
||||
|
||||
// 无需授权的接口
|
||||
group.GET("get", h.Get)
|
||||
group.POST("check", h.Check)
|
||||
group.GET("slide/get", h.SlideGet)
|
||||
group.POST("slide/check", h.SlideCheck)
|
||||
group.GET("config", h.GetConfig)
|
||||
}
|
||||
|
||||
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
|
||||
resp.SUCCESS(c, gin.H{"enabled": h.service.GetConfig().Enabled, "type": h.service.GetConfig().Type})
|
||||
}
|
||||
|
||||
func (h *CaptchaHandler) Get(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.service.Get()
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
@@ -36,6 +58,11 @@ func (h *CaptchaHandler) Get(c *gin.Context) {
|
||||
|
||||
// Check verify the captcha data
|
||||
func (h *CaptchaHandler) Check(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Key string `json:"key"`
|
||||
Dots string `json:"dots"`
|
||||
@@ -55,6 +82,11 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
|
||||
|
||||
// SlideGet 获取滑动验证图片
|
||||
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.service.SlideGet()
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
@@ -66,6 +98,11 @@ func (h *CaptchaHandler) SlideGet(c *gin.Context) {
|
||||
|
||||
// SlideCheck 滑动验证结果校验
|
||||
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
|
||||
if !h.service.GetConfig().Enabled {
|
||||
resp.ERROR(c, "验证码服务未启用")
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Key string `json:"key"`
|
||||
X int `json:"x"`
|
||||
|
||||
@@ -9,6 +9,7 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
@@ -19,18 +20,31 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatRoleHandler struct {
|
||||
type ChatAppHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
|
||||
return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
func NewChatAppHandler(app *core.AppServer, db *gorm.DB) *ChatAppHandler {
|
||||
return &ChatAppHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/app/")
|
||||
group.GET("list", h.List)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list/user", h.ListByUser)
|
||||
group.POST("update", h.UpdateApp)
|
||||
}
|
||||
}
|
||||
|
||||
// List 获取用户聊天应用列表
|
||||
func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
func (h *ChatAppHandler) List(c *gin.Context) {
|
||||
tid := h.GetInt(c, "tid", 0)
|
||||
var roles []model.ChatRole
|
||||
var roles []model.ChatApp
|
||||
session := h.DB.Where("enable", true)
|
||||
if tid > 0 {
|
||||
session = session.Where("tid", tid)
|
||||
@@ -41,9 +55,9 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var roleVos = make([]vo.ChatRole, 0)
|
||||
var roleVos = make([]vo.ChatApp, 0)
|
||||
for _, r := range roles {
|
||||
var v vo.ChatRole
|
||||
var v vo.ChatApp
|
||||
err := utils.CopyObject(r, &v)
|
||||
if err == nil {
|
||||
v.Id = r.Id
|
||||
@@ -54,10 +68,10 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
// ListByUser 获取用户添加的角色列表
|
||||
func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
||||
func (h *ChatAppHandler) ListByUser(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
var roles []model.ChatRole
|
||||
var roles []model.ChatApp
|
||||
session := h.DB.Where("enable", true)
|
||||
// 如果用户没登录,则获取所有角色
|
||||
if userId > 0 {
|
||||
@@ -86,9 +100,9 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var roleVos = make([]vo.ChatRole, 0)
|
||||
var roleVos = make([]vo.ChatApp, 0)
|
||||
for _, r := range roles {
|
||||
var v vo.ChatRole
|
||||
var v vo.ChatApp
|
||||
err := utils.CopyObject(r, &v)
|
||||
if err == nil {
|
||||
v.Id = r.Id
|
||||
@@ -98,8 +112,8 @@ func (h *ChatRoleHandler) ListByUser(c *gin.Context) {
|
||||
resp.SUCCESS(c, roleVos)
|
||||
}
|
||||
|
||||
// UpdateRole 更新用户聊天角色
|
||||
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
|
||||
// UpdateApp 更新用户聊天应用
|
||||
func (h *ChatAppHandler) UpdateApp(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
@@ -19,6 +19,12 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
|
||||
return &ChatAppTypeHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatAppTypeHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/app/type/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 获取App类型列表
|
||||
func (h *ChatAppTypeHandler) List(c *gin.Context) {
|
||||
var items []model.AppType
|
||||
|
||||
@@ -14,8 +14,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
@@ -39,6 +41,7 @@ import (
|
||||
const (
|
||||
ChatEventStart = "start"
|
||||
ChatEventEnd = "end"
|
||||
ChatEventComplete = "complete"
|
||||
ChatEventError = "error"
|
||||
ChatEventMessageDelta = "message_delta"
|
||||
ChatEventTitle = "title"
|
||||
@@ -54,44 +57,69 @@ type ChatInput struct {
|
||||
Stream bool `json:"stream"`
|
||||
Files []vo.File `json:"files"`
|
||||
ChatModel model.ChatModel `json:"chat_model,omitempty"`
|
||||
ChatRole model.ChatRole `json:"chat_role,omitempty"`
|
||||
ChatRole model.ChatApp `json:"chat_role,omitempty"`
|
||||
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID,用于重新生成答案的时候过滤上下文
|
||||
}
|
||||
|
||||
type ChatHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
licenseService *service.LicenseService
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
userService *service.UserService
|
||||
redis *redis.Client
|
||||
uploadManager *oss.UploaderManager
|
||||
licenseService *service.LicenseService
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService, moderationManager *moderation.ServiceManager) *ChatHandler {
|
||||
return &ChatHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
redis: redis,
|
||||
uploadManager: manager,
|
||||
licenseService: licenseService,
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
userService: userService,
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
redis: redis,
|
||||
uploadManager: manager,
|
||||
licenseService: licenseService,
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/chat/")
|
||||
|
||||
// 聊天接口不需要授权(已在authConfig中配置)
|
||||
group.Any("message", h.Chat)
|
||||
|
||||
// 其他接口需要用户授权
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.GET("detail", h.Detail)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("history", h.History)
|
||||
group.GET("clear", h.Clear)
|
||||
group.POST("tokens", h.Tokens)
|
||||
group.GET("stop", h.StopGenerate)
|
||||
group.POST("tts", h.TextToSpeech)
|
||||
}
|
||||
}
|
||||
|
||||
// Chat 处理聊天请求
|
||||
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
var input ChatInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Prompt-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
var input ChatInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
pushMessage(c, ChatEventError, types.InvalidArgs)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
@@ -113,7 +141,7 @@ func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证聊天角色
|
||||
var chatRole model.ChatRole
|
||||
var chatRole model.ChatApp
|
||||
err := h.DB.First(&chatRole, input.RoleId).Error
|
||||
if err != nil || !chatRole.Enable {
|
||||
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
|
||||
@@ -166,7 +194,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
}
|
||||
|
||||
if userVo.Power < input.ChatModel.Power {
|
||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, input.ChatModel.Power)
|
||||
return fmt.Errorf("您的算力不足,请购买算力。")
|
||||
}
|
||||
|
||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||
@@ -229,17 +257,24 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
// 加载聊天上下文
|
||||
chatCtx := make([]any, 0)
|
||||
messages := make([]any, 0)
|
||||
if h.App.SysConfig.EnableContext {
|
||||
if h.App.SysConfig.Base.EnableContext {
|
||||
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
|
||||
if h.App.SysConfig.ContextDeep > 0 {
|
||||
if h.App.SysConfig.Base.ContextDeep > 0 {
|
||||
var historyMessages []model.ChatMessage
|
||||
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
|
||||
if input.LastMsgId > 0 { // 重新生成逻辑
|
||||
var lastMessage model.ChatMessage
|
||||
err = dbSession.Where("id <= ?", input.LastMsgId).Where("type", types.PromptMsg).First(&lastMessage).Error
|
||||
if err != nil {
|
||||
input.LastMsgId = 0
|
||||
} else {
|
||||
input.LastMsgId = lastMessage.Id
|
||||
}
|
||||
dbSession = dbSession.Where("id < ?", input.LastMsgId)
|
||||
// 删除对应的聊天记录
|
||||
h.DB.Debug().Where("chat_id", input.ChatId).Where("id >= ?", input.LastMsgId).Delete(&model.ChatMessage{})
|
||||
}
|
||||
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error
|
||||
err = dbSession.Limit(h.App.SysConfig.Base.ContextDeep).Order("id DESC").Find(&historyMessages).Error
|
||||
if err == nil {
|
||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||
msg := historyMessages[i]
|
||||
@@ -267,7 +302,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
}
|
||||
|
||||
// 上下文的深度超出了模型的最大上下文深度
|
||||
if len(chatCtx) >= h.App.SysConfig.ContextDeep {
|
||||
if len(chatCtx) >= h.App.SysConfig.Base.ContextDeep {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -277,6 +312,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
}
|
||||
reqMgs := make([]any, 0)
|
||||
|
||||
// 添加引导提示词,防止模型生成违规内容
|
||||
if h.App.SysConfig.Moderation.EnableGuide {
|
||||
reqMgs = append(reqMgs, map[string]any{
|
||||
"role": "system",
|
||||
"content": h.App.SysConfig.Moderation.GuidePrompt,
|
||||
})
|
||||
}
|
||||
|
||||
for i := len(chatCtx) - 1; i >= 0; i-- {
|
||||
reqMgs = append(reqMgs, chatCtx[i])
|
||||
}
|
||||
@@ -295,16 +338,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// 如果不是逆向模型,则提取文件内容
|
||||
modelValue := input.ChatModel.Value
|
||||
if !(strings.Contains(modelValue, "-all") || strings.HasPrefix(modelValue, "gpt-4-gizmo") || strings.HasPrefix(modelValue, "claude")) {
|
||||
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
continue
|
||||
} else {
|
||||
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
// 处理文件,提取文件内容
|
||||
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
continue
|
||||
} else {
|
||||
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
logger.Debugf("fileContents: %s", fileContents)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -320,16 +361,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.C
|
||||
}
|
||||
|
||||
if len(imgList) > 0 {
|
||||
imgList = append(imgList, map[string]interface{}{
|
||||
imgList = append(imgList, map[string]any{
|
||||
"type": "text",
|
||||
"text": input.Prompt,
|
||||
})
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
req.Messages = append(reqMgs, map[string]any{
|
||||
"role": "user",
|
||||
"content": imgList,
|
||||
})
|
||||
} else {
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
req.Messages = append(reqMgs, map[string]any{
|
||||
"role": "user",
|
||||
"content": finalPrompt,
|
||||
})
|
||||
@@ -445,7 +486,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if input.ChatModel.KeyId > 0 {
|
||||
h.DB.Where("id", input.ChatModel.KeyId).Find(apiKey)
|
||||
h.DB.Where("id", input.ChatModel.KeyId).Where("enabled", true).Find(apiKey)
|
||||
} else { // use the last unused key
|
||||
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||
}
|
||||
@@ -516,6 +557,7 @@ func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens
|
||||
}
|
||||
|
||||
func (h *ChatHandler) saveChatHistory(
|
||||
c *gin.Context,
|
||||
req types.ApiRequest,
|
||||
usage Usage,
|
||||
message types.Message,
|
||||
@@ -524,6 +566,34 @@ func (h *ChatHandler) saveChatHistory(
|
||||
promptCreatedAt time.Time,
|
||||
replyCreatedAt time.Time) {
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(usage.Content)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
logger.Debugf("moderationResult: %+v", moderationResult)
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: userVo.Id,
|
||||
Source: types.ModerationSourceChat,
|
||||
Input: usage.Prompt,
|
||||
Output: usage.Content,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
pushMessage(c, ChatEventError, "很抱歉,内容触发敏感词预警,AI 无法回答!!!")
|
||||
// 更新用户算力
|
||||
if input.ChatModel.Power > 0 {
|
||||
h.subUserPower(userVo, input, 0, 0)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
// 追加聊天记录
|
||||
// for prompt
|
||||
var promptTokens, replyTokens, totalTokens int
|
||||
@@ -586,6 +656,22 @@ func (h *ChatHandler) saveChatHistory(
|
||||
logger.Error("failed to save reply history message: ", err)
|
||||
}
|
||||
|
||||
// 发送完整聊天记录给前端
|
||||
var messageVo vo.ChatMessage
|
||||
err = utils.CopyObject(historyReplyMsg, &messageVo)
|
||||
if err == nil {
|
||||
// 解析内容
|
||||
var content vo.MsgContent
|
||||
err = utils.JsonDecode(historyReplyMsg.Content, &content)
|
||||
if err != nil {
|
||||
content.Text = historyReplyMsg.Content
|
||||
}
|
||||
messageVo.Content = content
|
||||
messageVo.CreatedAt = historyReplyMsg.CreatedAt.Unix()
|
||||
messageVo.UpdatedAt = historyReplyMsg.UpdatedAt.Unix()
|
||||
pushMessage(c, ChatEventComplete, messageVo)
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
if input.ChatModel.Power > 0 {
|
||||
h.subUserPower(userVo, input, promptTokens, replyTokens)
|
||||
@@ -710,221 +796,3 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
logger.Error("写入音频数据到响应失败:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// // OPenAI 消息发送实现
|
||||
// func (h *ChatHandler) sendOpenAiMessage(
|
||||
// req types.ApiRequest,
|
||||
// userVo vo.User,
|
||||
// ctx context.Context,
|
||||
// session *types.ChatSession,
|
||||
// role model.ChatRole,
|
||||
// prompt string,
|
||||
// c *gin.Context) error {
|
||||
// promptCreatedAt := time.Now() // 记录提问时间
|
||||
// start := time.Now()
|
||||
// var apiKey = model.ApiKey{}
|
||||
// response, err := h.doRequest(ctx, req, session, &apiKey)
|
||||
// logger.Info("HTTP请求完成,耗时:", time.Since(start))
|
||||
// if err != nil {
|
||||
// if strings.Contains(err.Error(), "context canceled") {
|
||||
// return fmt.Errorf("用户取消了请求:%s", prompt)
|
||||
// } else if strings.Contains(err.Error(), "no available key") {
|
||||
// return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||
// }
|
||||
// return err
|
||||
// } else {
|
||||
// defer response.Body.Close()
|
||||
// }
|
||||
|
||||
// if response.StatusCode != 200 {
|
||||
// body, _ := io.ReadAll(response.Body)
|
||||
// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
|
||||
// }
|
||||
|
||||
// contentType := response.Header.Get("Prompt-Type")
|
||||
// if strings.Contains(contentType, "text/event-stream") {
|
||||
// replyCreatedAt := time.Now() // 记录回复时间
|
||||
// // 循环读取 Chunk 消息
|
||||
// var message = types.Message{Role: "assistant"}
|
||||
// var contents = make([]string, 0)
|
||||
// var function model.Function
|
||||
// var toolCall = false
|
||||
// var arguments = make([]string, 0)
|
||||
// var reasoning = false
|
||||
|
||||
// pushMessage(c, ChatEventStart, "开始响应")
|
||||
// scanner := bufio.NewScanner(response.Body)
|
||||
// for scanner.Scan() {
|
||||
// line := scanner.Text()
|
||||
// if !strings.Contains(line, "data:") || len(line) < 30 {
|
||||
// continue
|
||||
// }
|
||||
// var responseBody = types.ApiResponse{}
|
||||
// err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
// if err != nil { // 数据解析出错
|
||||
// return errors.New(line)
|
||||
// }
|
||||
// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
// continue
|
||||
// }
|
||||
// if responseBody.Choices[0].Delta.Prompt == nil &&
|
||||
// responseBody.Choices[0].Delta.ToolCalls == nil &&
|
||||
// responseBody.Choices[0].Delta.ReasoningContent == "" {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||
// pushMessage(c, ChatEventError, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||
// break
|
||||
// }
|
||||
|
||||
// var tool types.ToolCall
|
||||
// if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
|
||||
// tool = responseBody.Choices[0].Delta.ToolCalls[0]
|
||||
// if toolCall && tool.Function.Name == "" {
|
||||
// arguments = append(arguments, tool.Function.Arguments)
|
||||
// continue
|
||||
// }
|
||||
// }
|
||||
|
||||
// // 兼容 Function Call
|
||||
// fun := responseBody.Choices[0].Delta.FunctionCall
|
||||
// if fun.Name != "" {
|
||||
// tool = *new(types.ToolCall)
|
||||
// tool.Function.Name = fun.Name
|
||||
// } else if toolCall {
|
||||
// arguments = append(arguments, fun.Arguments)
|
||||
// continue
|
||||
// }
|
||||
|
||||
// if !utils.IsEmptyValue(tool) {
|
||||
// res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
|
||||
// if res.Error == nil {
|
||||
// toolCall = true
|
||||
// callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": callMsg,
|
||||
// })
|
||||
// contents = append(contents, callMsg)
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
|
||||
// if responseBody.Choices[0].FinishReason == "tool_calls" ||
|
||||
// responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||
// break
|
||||
// }
|
||||
|
||||
// // output stopped
|
||||
// if responseBody.Choices[0].FinishReason != "" {
|
||||
// break // 输出完成或者输出中断了
|
||||
// } else { // 正常输出结果
|
||||
// // 兼容思考过程
|
||||
// if responseBody.Choices[0].Delta.ReasoningContent != "" {
|
||||
// reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
|
||||
// if !reasoning {
|
||||
// reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
|
||||
// reasoning = true
|
||||
// }
|
||||
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": reasoningContent,
|
||||
// })
|
||||
// contents = append(contents, reasoningContent)
|
||||
// } else if responseBody.Choices[0].Delta.Prompt != "" {
|
||||
// finalContent := responseBody.Choices[0].Delta.Prompt
|
||||
// if reasoning {
|
||||
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Prompt)
|
||||
// reasoning = false
|
||||
// }
|
||||
// contents = append(contents, utils.InterfaceToString(finalContent))
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": finalContent,
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// } // end for
|
||||
|
||||
// if err := scanner.Err(); err != nil {
|
||||
// if strings.Contains(err.Error(), "context canceled") {
|
||||
// logger.Info("用户取消了请求:", prompt)
|
||||
// } else {
|
||||
// logger.Error("信息读取出错:", err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if toolCall { // 调用函数完成任务
|
||||
// params := make(map[string]any)
|
||||
// _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
// logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||
// params["user_id"] = userVo.Id
|
||||
// var apiRes types.BizVo
|
||||
// r, err := req2.C().R().SetHeader("Body-Type", "application/json").
|
||||
// SetHeader("Authorization", function.Token).
|
||||
// SetBody(params).Post(function.Action)
|
||||
// errMsg := ""
|
||||
// if err != nil {
|
||||
// errMsg = err.Error()
|
||||
// } else {
|
||||
// all, _ := io.ReadAll(r.Body)
|
||||
// err = json.Unmarshal(all, &apiRes)
|
||||
// if err != nil {
|
||||
// errMsg = err.Error()
|
||||
// } else if apiRes.Code != types.Success {
|
||||
// errMsg = apiRes.Message
|
||||
// }
|
||||
// }
|
||||
|
||||
// if errMsg != "" {
|
||||
// errMsg = "调用函数工具出错:" + errMsg
|
||||
// contents = append(contents, errMsg)
|
||||
// } else {
|
||||
// errMsg = utils.InterfaceToString(apiRes.Data)
|
||||
// contents = append(contents, errMsg)
|
||||
// }
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": errMsg,
|
||||
// })
|
||||
// }
|
||||
|
||||
// // 消息发送成功
|
||||
// if len(contents) > 0 {
|
||||
// usage := Usage{
|
||||
// Prompt: prompt,
|
||||
// Prompt: strings.Join(contents, ""),
|
||||
// PromptTokens: 0,
|
||||
// CompletionTokens: 0,
|
||||
// TotalTokens: 0,
|
||||
// }
|
||||
// message.Prompt = usage.Prompt
|
||||
// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
// }
|
||||
// } else {
|
||||
// var respVo OpenAIResVo
|
||||
// body, err := io.ReadAll(response.Body)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("读取响应失败:%v", body)
|
||||
// }
|
||||
// err = json.Unmarshal(body, &respVo)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("解析响应失败:%v", body)
|
||||
// }
|
||||
// content := respVo.Choices[0].Message.Prompt
|
||||
// if strings.HasPrefix(req.Model, "o1-") {
|
||||
// content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Prompt)
|
||||
// }
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": content,
|
||||
// })
|
||||
// respVo.Usage.Prompt = prompt
|
||||
// respVo.Usage.Prompt = content
|
||||
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
@@ -42,9 +42,9 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
modelValues = append(modelValues, chat.Model)
|
||||
}
|
||||
|
||||
var roles []model.ChatRole
|
||||
var roles []model.ChatApp
|
||||
var models []model.ChatModel
|
||||
roleMap := make(map[uint]model.ChatRole)
|
||||
roleMap := make(map[uint]model.ChatApp)
|
||||
modelMap := make(map[string]model.ChatModel)
|
||||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||||
h.DB.Where("value IN ?", modelValues).Find(&models)
|
||||
@@ -205,7 +205,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 填充角色名称
|
||||
var role model.ChatRole
|
||||
var role model.ChatApp
|
||||
res = h.DB.Where("id", chatItem.RoleId).First(&role)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "Role not found")
|
||||
|
||||
@@ -26,6 +26,12 @@ func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
|
||||
return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ChatModelHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/model/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 模型列表
|
||||
func (h *ChatModelHandler) List(c *gin.Context) {
|
||||
var items []model.ChatModel
|
||||
|
||||
@@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
TotalTokens: 0,
|
||||
}
|
||||
message.Content = usage.Content
|
||||
h.saveChatHistory(req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
|
||||
h.saveChatHistory(c, req, usage, message, input, userVo, promptCreatedAt, replyCreatedAt)
|
||||
}
|
||||
} else { // 非流式输出
|
||||
var respVo OpenAIResVo
|
||||
@@ -242,7 +242,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
pushMessage(c, "text", content)
|
||||
respVo.Usage.Prompt = input.Prompt
|
||||
respVo.Usage.Content = content
|
||||
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
|
||||
h.saveChatHistory(c, req, respVo.Usage, respVo.Choices[0].Message, input, userVo, promptCreatedAt, time.Now())
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -27,6 +27,15 @@ func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.
|
||||
return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ConfigHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/config/")
|
||||
|
||||
// 无需授权的接口
|
||||
group.GET("get", h.Get)
|
||||
group.GET("license", h.License)
|
||||
}
|
||||
|
||||
// Get 获取指定的系统配置
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
|
||||
@@ -10,9 +10,11 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
@@ -25,16 +27,18 @@ import (
|
||||
|
||||
type DallJobHandler struct {
|
||||
BaseHandler
|
||||
dallService *dalle.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
dallService *dalle.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
|
||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *DallJobHandler {
|
||||
return &DallJobHandler{
|
||||
dallService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
dallService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
@@ -42,6 +46,24 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *DallJobHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/dall/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("models", h.GetModels)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
var data types.DallTask
|
||||
@@ -50,6 +72,29 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceDalle,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,提示词未通过文本审核,请重新输入!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
|
||||
resp.ERROR(c, "模型不存在")
|
||||
@@ -73,11 +118,12 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
UserId: uint(userId),
|
||||
ModelId: chatModel.Id,
|
||||
ModelName: chatModel.Value,
|
||||
Image: data.Image,
|
||||
Prompt: data.Prompt,
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
Power: chatModel.Power,
|
||||
}
|
||||
job := model.DallJob{
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/crawler"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
@@ -31,7 +30,6 @@ import (
|
||||
|
||||
type FunctionHandler struct {
|
||||
BaseHandler
|
||||
config types.ApiConfig
|
||||
uploadManager *oss.UploaderManager
|
||||
dallService *dalle.Service
|
||||
userService *service.UserService
|
||||
@@ -49,13 +47,23 @@ func NewFunctionHandler(
|
||||
App: server,
|
||||
DB: db,
|
||||
},
|
||||
config: config.ApiConfig,
|
||||
uploadManager: manager,
|
||||
dallService: dallService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *FunctionHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/function/")
|
||||
group.GET("list", h.List)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.POST("weibo", h.WeiBo)
|
||||
group.POST("zaobao", h.ZaoBao)
|
||||
group.POST("dalle3", h.Dall3)
|
||||
}
|
||||
|
||||
type resVo struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
@@ -107,16 +115,10 @@ func (h *FunctionHandler) WeiBo(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Token == "" {
|
||||
resp.ERROR(c, "无效的 API Token")
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
|
||||
url := fmt.Sprintf("%s/api/weibo/fetch", types.GeekAPIURL)
|
||||
var res resVo
|
||||
r, err := req.C().R().
|
||||
SetHeader("AppId", h.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||
SetHeader("Authorization", "Bearer geekai-plus").
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||
@@ -146,16 +148,10 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Token == "" {
|
||||
resp.ERROR(c, "无效的 API Token")
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
|
||||
url := fmt.Sprintf("%s/api/zaobao/fetch", types.GeekAPIURL)
|
||||
var res resVo
|
||||
r, err := req.C().R().
|
||||
SetHeader("AppId", h.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||
SetHeader("Authorization", "Bearer geekai-plus").
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||||
@@ -193,16 +189,23 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
res := h.DB.Where("type = ?", "img").Where("enabled", true).First(&chatModel)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "没有找到可用的AI绘图模型!")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("绘画参数:%+v", params)
|
||||
var user model.User
|
||||
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||
res = h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "当前用户不存在!")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.DallPower {
|
||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
|
||||
if user.Power < chatModel.Power {
|
||||
resp.ERROR(c, "创建绘图任务失败,算力不足")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -211,24 +214,24 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
task := types.DallTask{
|
||||
UserId: user.Id,
|
||||
Prompt: prompt,
|
||||
ModelId: 0,
|
||||
ModelName: "dall-e-3",
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
ModelId: chatModel.Id,
|
||||
ModelName: chatModel.Value,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
N: 1,
|
||||
Quality: "standard",
|
||||
Size: "1024x1024",
|
||||
Style: "vivid",
|
||||
Power: h.App.SysConfig.DallPower,
|
||||
Power: chatModel.Power,
|
||||
}
|
||||
job := model.DallJob{
|
||||
UserId: user.Id,
|
||||
Prompt: prompt,
|
||||
Power: h.App.SysConfig.DallPower,
|
||||
Power: chatModel.Power,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
}
|
||||
err := h.DB.Create(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error())
|
||||
resp.ERROR(c, "创建绘图任务失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -253,76 +256,6 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
// 实现一个联网搜索的函数工具,采用爬虫实现
|
||||
func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
if err := h.checkAuth(c); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var params map[string]interface{}
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 从参数中获取搜索关键词
|
||||
keyword, ok := params["keyword"].(string)
|
||||
if !ok || keyword == "" {
|
||||
resp.ERROR(c, "搜索关键词不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 从参数中获取最大页数,默认为1页
|
||||
maxPages := 1
|
||||
if pages, ok := params["max_pages"].(float64); ok {
|
||||
maxPages = int(pages)
|
||||
}
|
||||
|
||||
// 获取用户ID
|
||||
userID, ok := params["user_id"].(float64)
|
||||
if !ok {
|
||||
resp.ERROR(c, "用户ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 查询用户信息
|
||||
var user model.User
|
||||
res := h.DB.Where("id = ?", int(userID)).First(&user)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户算力是否足够
|
||||
searchPower := 1 // 每次搜索消耗1点算力
|
||||
if user.Power < searchPower {
|
||||
resp.ERROR(c, "算力不足,无法执行网络搜索")
|
||||
return
|
||||
}
|
||||
|
||||
// 执行网络搜索
|
||||
searchResults, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 扣减用户算力
|
||||
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "web_search",
|
||||
Remark: fmt.Sprintf("网络搜索:%s", utils.CutWords(keyword, 10)),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回搜索结果
|
||||
resp.SUCCESS(c, searchResults)
|
||||
}
|
||||
|
||||
// List 获取所有的工具函数列表
|
||||
func (h *FunctionHandler) List(c *gin.Context) {
|
||||
var items []model.Function
|
||||
|
||||
@@ -8,14 +8,18 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// InviteHandler 用户邀请
|
||||
@@ -27,6 +31,23 @@ func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
|
||||
return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *InviteHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/invite/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("hits", h.Hits)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("code", h.Code)
|
||||
group.GET("list", h.List)
|
||||
group.GET("stats", h.Stats)
|
||||
group.GET("rules", h.Rules)
|
||||
}
|
||||
}
|
||||
|
||||
// Code 获取当前用户邀请码
|
||||
func (h *InviteHandler) Code(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
@@ -65,21 +86,34 @@ func (h *InviteHandler) List(c *gin.Context) {
|
||||
var total int64
|
||||
session.Model(&model.InviteLog{}).Count(&total)
|
||||
var items []model.InviteLog
|
||||
var list = make([]vo.InviteLog, 0)
|
||||
offset := (page - 1) * pageSize
|
||||
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
|
||||
if res.Error == nil {
|
||||
for _, item := range items {
|
||||
var v vo.InviteLog
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err == nil {
|
||||
v.Id = item.Id
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
list = append(list, v)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
userIds := make([]uint, 0)
|
||||
for _, item := range items {
|
||||
userIds = append(userIds, item.UserId)
|
||||
}
|
||||
userMap := make(map[uint]model.User)
|
||||
var users []model.User
|
||||
h.DB.Model(&model.User{}).Where("id IN (?)", userIds).Find(&users)
|
||||
for _, user := range users {
|
||||
userMap[user.Id] = user
|
||||
}
|
||||
|
||||
var list = make([]vo.InviteLog, 0)
|
||||
for _, item := range items {
|
||||
var v vo.InviteLog
|
||||
err := utils.CopyObject(item, &v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
v.CreatedAt = item.CreatedAt.Unix()
|
||||
v.Avatar = userMap[item.UserId].Avatar
|
||||
list = append(list, v)
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||
}
|
||||
@@ -90,3 +124,89 @@ func (h *InviteHandler) Hits(c *gin.Context) {
|
||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// Stats 获取邀请统计
|
||||
func (h *InviteHandler) Stats(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
// 获取邀请码
|
||||
var inviteCode model.InviteCode
|
||||
res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "邀请码不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 统计累计邀请数
|
||||
var totalInvite int64
|
||||
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ?", userId).Count(&totalInvite)
|
||||
|
||||
// 统计今日邀请数
|
||||
today := time.Now().Format("2006-01-02")
|
||||
var todayInvite int64
|
||||
h.DB.Model(&model.InviteLog{}).Where("inviter_id = ? AND DATE(created_at) = ?", userId, today).Count(&todayInvite)
|
||||
|
||||
// 获取系统配置中的邀请奖励
|
||||
var config model.Config
|
||||
var invitePower int = 200 // 默认值
|
||||
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
|
||||
var configMap map[string]any
|
||||
if utils.JsonDecode(config.Value, &configMap) == nil {
|
||||
if power, ok := configMap["invite_power"].(float64); ok {
|
||||
invitePower = int(power)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算获得奖励总数
|
||||
rewardTotal := int(totalInvite) * invitePower
|
||||
|
||||
// 构建邀请链接
|
||||
inviteLink := fmt.Sprintf("%s/register?invite=%s", h.App.Config.StaticUrl, inviteCode.Code)
|
||||
|
||||
stats := vo.InviteStats{
|
||||
InviteCount: int(totalInvite),
|
||||
RewardTotal: rewardTotal,
|
||||
TodayInvite: int(todayInvite),
|
||||
InviteCode: inviteCode.Code,
|
||||
InviteLink: inviteLink,
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, stats)
|
||||
}
|
||||
|
||||
// Rules 获取奖励规则
|
||||
func (h *InviteHandler) Rules(c *gin.Context) {
|
||||
// 获取系统配置中的邀请奖励
|
||||
var config model.Config
|
||||
var invitePower int = 200 // 默认值
|
||||
if h.DB.Where("name = ?", "system").First(&config).Error == nil {
|
||||
var configMap map[string]interface{}
|
||||
if utils.JsonDecode(config.Value, &configMap) == nil {
|
||||
if power, ok := configMap["invite_power"].(float64); ok {
|
||||
invitePower = int(power)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rules := []vo.RewardRule{
|
||||
{
|
||||
Id: 1,
|
||||
Title: "好友注册",
|
||||
Desc: "好友通过邀请链接成功注册",
|
||||
Icon: "icon-user-fill",
|
||||
Color: "#1989fa",
|
||||
Reward: invitePower,
|
||||
},
|
||||
{
|
||||
Id: 2,
|
||||
Title: "好友首次充值",
|
||||
Desc: "好友首次充值任意金额",
|
||||
Icon: "icon-money",
|
||||
Color: "#07c160",
|
||||
Reward: invitePower * 2, // 假设首次充值奖励是注册奖励的2倍
|
||||
},
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, rules)
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/moderation"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
@@ -19,27 +20,34 @@ import (
|
||||
// JimengHandler 即梦AI处理器
|
||||
type JimengHandler struct {
|
||||
BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
// NewJimengHandler 创建即梦AI处理器
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler {
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService, moderationManager *moderation.ServiceManager) *JimengHandler {
|
||||
return &JimengHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由,新增统一任务接口
|
||||
func (h *JimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/jimeng")
|
||||
rg.POST("task", h.CreateTask) // 只保留统一任务接口
|
||||
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口
|
||||
rg.POST("jobs", h.Jobs)
|
||||
rg.GET("remove", h.Remove)
|
||||
rg.GET("retry", h.Retry)
|
||||
group := h.App.Engine.Group("/api/jimeng/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("task", h.CreateTask)
|
||||
group.GET("power-config", h.GetPowerConfig)
|
||||
group.POST("jobs", h.Jobs)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("retry", h.Retry)
|
||||
}
|
||||
}
|
||||
|
||||
// JimengTaskRequest 统一任务请求结构体
|
||||
@@ -70,6 +78,31 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(req.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceJiMeng,
|
||||
Input: req.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 新增:除图像特效外,其他任务类型必须有提示词
|
||||
if req.TaskType != "image_effects" && req.Prompt == "" {
|
||||
resp.ERROR(c, "提示词不能为空")
|
||||
@@ -153,12 +186,7 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
params["image_urls"] = []string{req.ImageInput}
|
||||
case "image_effects":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
||||
taskType = model.JMTaskTypeImageEffects
|
||||
@@ -181,9 +209,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
taskType = model.JMTaskTypeTextToVideo
|
||||
reqKey = jimeng.ReqKeyTextToVideo
|
||||
modelName = "即梦文生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
@@ -196,9 +221,6 @@ func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
taskType = model.JMTaskTypeImageToVideo
|
||||
reqKey = jimeng.ReqKeyImageToVideo
|
||||
modelName = "即梦图生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
@@ -333,8 +355,10 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能删除")
|
||||
|
||||
// 正在运行中的任务不能删除
|
||||
if job.Status == model.JMTaskStatusGenerating || job.Status == model.JMTaskStatusInQueue {
|
||||
resp.ERROR(c, "正在运行中的任务不能删除,否则无法退回算力")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -345,17 +369,20 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 退回算力
|
||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "退回算力失败")
|
||||
tx.Rollback()
|
||||
return
|
||||
// 失败任务删除后退回算力
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "退回算力失败")
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
@@ -408,7 +435,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
|
||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
config := h.jimengService.GetConfig()
|
||||
config := h.App.SysConfig.Jimeng
|
||||
|
||||
switch taskType {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
@@ -430,7 +457,7 @@ func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
|
||||
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||
config := h.jimengService.GetConfig()
|
||||
config := h.App.SysConfig.Jimeng
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"text_to_image": config.Power.TextToImage,
|
||||
"image_to_image": config.Power.ImageToImage,
|
||||
|
||||
@@ -10,6 +10,7 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
@@ -35,6 +36,17 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MarkMapHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/markMap/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("gen", h.Generate)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate 生成思维导图
|
||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
var data struct {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -25,6 +26,12 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
|
||||
return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MenuHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/menu/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 数据列表
|
||||
func (h *MenuHandler) List(c *gin.Context) {
|
||||
index := h.GetBool(c, "index")
|
||||
@@ -33,7 +40,7 @@ func (h *MenuHandler) List(c *gin.Context) {
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
session = session.Where("enabled", true)
|
||||
if index {
|
||||
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
|
||||
session = session.Where("id IN ?", h.App.SysConfig.Base.IndexNavs)
|
||||
}
|
||||
res := session.Order("sort_num ASC").Find(&items)
|
||||
if res.Error == nil {
|
||||
|
||||
@@ -10,9 +10,11 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
@@ -27,18 +29,20 @@ import (
|
||||
|
||||
type MidJourneyHandler struct {
|
||||
BaseHandler
|
||||
mjService *mj.Service
|
||||
snowflake *service.Snowflake
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
mjService *mj.Service
|
||||
snowflake *service.Snowflake
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *MidJourneyHandler {
|
||||
return &MidJourneyHandler{
|
||||
snowflake: snowflake,
|
||||
mjService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
snowflake: snowflake,
|
||||
mjService: service,
|
||||
uploader: manager,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
@@ -46,6 +50,25 @@ func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.S
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *MidJourneyHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/mj/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("image", h.Image)
|
||||
group.POST("upscale", h.Upscale)
|
||||
group.POST("variation", h.Variation)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -53,7 +76,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.MjPower {
|
||||
if user.Power < h.App.SysConfig.Base.MjPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return false
|
||||
}
|
||||
@@ -90,6 +113,29 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceMJ,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var params = ""
|
||||
if data.Rate != "" && !strings.Contains(params, "--ar") {
|
||||
params += " --ar " + data.Rate
|
||||
@@ -159,8 +205,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
Mode: h.App.SysConfig.Base.MjMode,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: data.TaskType,
|
||||
@@ -169,7 +215,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
|
||||
Power: h.App.SysConfig.MjPower,
|
||||
Power: h.App.SysConfig.Base.MjPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
opt := "绘图"
|
||||
@@ -232,7 +278,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
Index: data.Index,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
Mode: h.App.SysConfig.Base.MjMode,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskUpscale.String(),
|
||||
@@ -240,7 +286,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
Power: h.App.SysConfig.MjActionPower,
|
||||
Power: h.App.SysConfig.Base.MjActionPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||
@@ -287,7 +333,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
ChannelId: data.ChannelId,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
Mode: h.App.SysConfig.Base.MjMode,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskVariation.String(),
|
||||
@@ -296,7 +342,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
TaskId: taskId,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Progress: 0,
|
||||
Power: h.App.SysConfig.MjActionPower,
|
||||
Power: h.App.SysConfig.Base.MjActionPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
|
||||
|
||||
@@ -9,6 +9,7 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
@@ -32,6 +33,22 @@ func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManage
|
||||
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *NetHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/upload")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("", h.Upload)
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
}
|
||||
|
||||
// 公开接口,不需要授权
|
||||
h.App.Engine.GET("/api/download", h.Download)
|
||||
}
|
||||
|
||||
func (h *NetHandler) Upload(c *gin.Context) {
|
||||
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
|
||||
if err != nil {
|
||||
|
||||
@@ -9,12 +9,12 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -28,6 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
|
||||
return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *OrderHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/order/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("list", h.List)
|
||||
group.GET("query", h.Query)
|
||||
}
|
||||
}
|
||||
|
||||
// List 订单列表
|
||||
func (h *OrderHandler) List(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
@@ -48,20 +60,21 @@ func (h *OrderHandler) List(c *gin.Context) {
|
||||
order.Id = item.Id
|
||||
order.CreatedAt = item.CreatedAt.Unix()
|
||||
order.UpdatedAt = item.UpdatedAt.Unix()
|
||||
payMethod, ok := types.PayMethods[item.PayWay]
|
||||
payChannel, ok := types.PayChannel[item.Channel]
|
||||
if !ok {
|
||||
payMethod = item.PayWay
|
||||
payChannel = item.PayWay
|
||||
}
|
||||
payName, ok := types.PayNames[item.PayType]
|
||||
payWays, ok := types.PayWays[item.PayWay]
|
||||
if !ok {
|
||||
payName = item.PayWay
|
||||
payWays = item.PayWay
|
||||
}
|
||||
order.PayMethod = payMethod
|
||||
order.PayName = payName
|
||||
order.ChannelName = payChannel
|
||||
order.PayName = payWays
|
||||
list = append(list, order)
|
||||
} else {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
|
||||
@@ -82,17 +95,8 @@ func (h *OrderHandler) Query(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
counter := 0
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
var item model.Order
|
||||
h.DB.Where("order_no = ?", orderNo).First(&item)
|
||||
if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
|
||||
order.Status = item.Status
|
||||
break
|
||||
}
|
||||
counter++
|
||||
}
|
||||
var item model.Order
|
||||
h.DB.Where("order_no = ?", orderNo).First(&item)
|
||||
|
||||
resp.SUCCESS(c, gin.H{"status": order.Status})
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/payment"
|
||||
@@ -33,52 +34,148 @@ type PayWay struct {
|
||||
// PaymentHandler 支付服务回调 handler
|
||||
type PaymentHandler struct {
|
||||
BaseHandler
|
||||
alipayService *payment.AlipayService
|
||||
huPiPayService *payment.HuPiPayService
|
||||
geekPayService *payment.GeekPayService
|
||||
wechatPayService *payment.WechatPayService
|
||||
snowflake *service.Snowflake
|
||||
userService *service.UserService
|
||||
fs embed.FS
|
||||
lock sync.Mutex
|
||||
signKey string // 用来签名的随机秘钥
|
||||
alipayService *payment.AlipayService
|
||||
epayService *payment.EPayService
|
||||
wxpayService *payment.WxPayService
|
||||
snowflake *service.Snowflake
|
||||
userService *service.UserService
|
||||
fs embed.FS
|
||||
lock sync.Mutex
|
||||
config *types.PaymentConfig
|
||||
}
|
||||
|
||||
func NewPaymentHandler(
|
||||
server *core.AppServer,
|
||||
alipayService *payment.AlipayService,
|
||||
huPiPayService *payment.HuPiPayService,
|
||||
geekPayService *payment.GeekPayService,
|
||||
wechatPayService *payment.WechatPayService,
|
||||
geekPayService *payment.EPayService,
|
||||
wxpayService *payment.WxPayService,
|
||||
db *gorm.DB,
|
||||
userService *service.UserService,
|
||||
snowflake *service.Snowflake,
|
||||
fs embed.FS) *PaymentHandler {
|
||||
fs embed.FS,
|
||||
sysConfig *types.SystemConfig) *PaymentHandler {
|
||||
return &PaymentHandler{
|
||||
alipayService: alipayService,
|
||||
huPiPayService: huPiPayService,
|
||||
geekPayService: geekPayService,
|
||||
wechatPayService: wechatPayService,
|
||||
snowflake: snowflake,
|
||||
userService: userService,
|
||||
fs: fs,
|
||||
lock: sync.Mutex{},
|
||||
alipayService: alipayService,
|
||||
epayService: geekPayService,
|
||||
wxpayService: wxpayService,
|
||||
snowflake: snowflake,
|
||||
userService: userService,
|
||||
fs: fs,
|
||||
lock: sync.Mutex{},
|
||||
BaseHandler: BaseHandler{
|
||||
App: server,
|
||||
DB: db,
|
||||
},
|
||||
signKey: utils.RandString(32),
|
||||
config: &sysConfig.Payment,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PaymentHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/payment/")
|
||||
|
||||
// 支付回调接口(公开)
|
||||
rg.POST("notify/alipay", h.AlipayNotify)
|
||||
rg.GET("notify/epay", h.EPayNotify)
|
||||
rg.POST("notify/wxpay", h.WxpayNotify)
|
||||
|
||||
// 需要用户登录的接口
|
||||
rg.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
rg.POST("create", h.CreateOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PaymentHandler) StartSyncOrders() {
|
||||
go func() {
|
||||
for {
|
||||
err := h.SyncOrders()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncOrders 同步订单状态
|
||||
func (h *PaymentHandler) SyncOrders() error {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Errorf("同步订单状态发生异常: %v", err)
|
||||
}
|
||||
}()
|
||||
var orders []model.Order
|
||||
err := h.DB.Where("status", types.OrderNotPaid).Where("checked", false).Find(&orders).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, order := range orders {
|
||||
time.Sleep(time.Second * 1)
|
||||
//超时15分钟的订单,直接标记为已关闭
|
||||
if time.Now().After(order.CreatedAt.Add(time.Minute * 5)) {
|
||||
h.DB.Model(&model.Order{}).Where("id", order.Id).Update("checked", true)
|
||||
logger.Errorf("订单超时:%v", order)
|
||||
continue
|
||||
}
|
||||
// 查询订单状态
|
||||
var res payment.OrderInfo
|
||||
switch order.Channel {
|
||||
case payment.PayChannelEpay:
|
||||
res, err = h.epayService.Query(order.OrderNo)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query order info: %v", err)
|
||||
continue
|
||||
}
|
||||
// 微信支付
|
||||
case payment.PayChannelWX:
|
||||
res, err = h.wxpayService.Query(order.OrderNo)
|
||||
logger.Debugf("微信支付订单状态:%+v", res)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query order info: %v", err)
|
||||
continue
|
||||
}
|
||||
case payment.PayChannelAL:
|
||||
res, err = h.alipayService.Query(order.OrderNo)
|
||||
if err != nil {
|
||||
logger.Errorf("error with query order info: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 订单已关闭
|
||||
if res.Closed() {
|
||||
h.DB.Model(&model.Order{}).Where("id", order.Id).Updates(map[string]any{
|
||||
"checked": true,
|
||||
"status": types.OrderPaidFailed,
|
||||
})
|
||||
logger.Errorf("订单已关闭:%v", order)
|
||||
continue
|
||||
}
|
||||
|
||||
// 订单未支付,不处理,继续轮询
|
||||
if !res.Success() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 订单支付成功
|
||||
err = h.paySuccess(res)
|
||||
if err != nil {
|
||||
logger.Errorf("error with deal order: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
var data struct {
|
||||
PayWay string `json:"pay_way"`
|
||||
PayType string `json:"pay_type"`
|
||||
ProductId int `json:"product_id"`
|
||||
UserId int `json:"user_id"`
|
||||
Device string `json:"device"`
|
||||
Host string `json:"host"`
|
||||
PayWay string `json:"pay_way,omitempty"` // 支付方式:支付宝,微信
|
||||
Pid int `json:"pid,omitempty"`
|
||||
Device string `json:"device,omitempty"`
|
||||
Domain string `json:"domain,omitempty"` // 支付回调域名
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
@@ -86,7 +183,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||
}
|
||||
|
||||
var product model.Product
|
||||
err := h.DB.Where("id", data.ProductId).First(&product).Error
|
||||
err := h.DB.Where("id", data.Pid).First(&product).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Product not found")
|
||||
return
|
||||
@@ -97,136 +194,118 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
}
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
err = h.DB.Where("id", data.UserId).First(&user).Error
|
||||
err = h.DB.Where("id", userId).First(&user).Error
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
amount := product.Discount
|
||||
var payURL, returnURL, notifyURL string
|
||||
amount := product.Price
|
||||
var payURL, notifyURL string
|
||||
switch data.PayWay {
|
||||
case "alipay":
|
||||
if h.App.Config.AlipayConfig.NotifyURL != "" { // 用于本地调试支付
|
||||
notifyURL = h.App.Config.AlipayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Host)
|
||||
}
|
||||
if h.App.Config.AlipayConfig.ReturnURL != "" { // 用于本地调试支付
|
||||
returnURL = h.App.Config.AlipayConfig.ReturnURL
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
money := fmt.Sprintf("%.2f", amount)
|
||||
if data.Device == "wechat" {
|
||||
payURL, err = h.alipayService.PayMobile(payment.AlipayParams{
|
||||
case "wxpay":
|
||||
logger.Debugf("微信支付,%+v", data)
|
||||
data.Channel = payment.PayChannelWX
|
||||
// 优先使用微信官方支付
|
||||
if h.config.WxPay.Enabled {
|
||||
data.Channel = "wxpay"
|
||||
if h.config.WxPay.Domain != "" {
|
||||
data.Domain = h.config.WxPay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/wxpay", data.Domain)
|
||||
payURL, err = h.wxpayService.Pay(payment.PayRequest{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: money,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
} else {
|
||||
payURL, err = h.alipayService.PayPC(payment.AlipayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: money,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||
return
|
||||
}
|
||||
break
|
||||
case "wechat":
|
||||
if h.App.Config.WechatPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.WechatPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/wechat", data.Host)
|
||||
}
|
||||
if data.Device == "wechat" {
|
||||
payURL, err = h.wechatPayService.PayUrlH5(payment.WechatPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
TotalFee: int(amount * 100),
|
||||
TotalFee: fmt.Sprintf("%d", int(amount*100)),
|
||||
Subject: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
PayWay: payment.PayWayWX,
|
||||
})
|
||||
} else {
|
||||
payURL, err = h.wechatPayService.PayUrlNative(payment.WechatPayParams{
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
} else if h.config.Epay.Enabled { // 聚合支付
|
||||
logger.Debugf("聚合支付%+v", data)
|
||||
data.Channel = payment.PayChannelEpay
|
||||
if h.config.Epay.Domain != "" {
|
||||
data.Domain = h.config.Epay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
|
||||
params := payment.PayRequest{
|
||||
OutTradeNo: orderNo,
|
||||
TotalFee: int(amount * 100),
|
||||
Subject: product.Name,
|
||||
TotalFee: fmt.Sprintf("%f", amount),
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
PayWay: payment.PayWayWX,
|
||||
NotifyURL: notifyURL,
|
||||
}
|
||||
|
||||
r, err := h.epayService.Pay(params)
|
||||
logger.Debugf("请求支付结果,%+v", r)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
} else {
|
||||
payURL = r
|
||||
}
|
||||
} else {
|
||||
resp.ERROR(c, "系统没有配置可用的支付渠道!")
|
||||
return
|
||||
}
|
||||
case "alipay":
|
||||
if h.config.Alipay.Enabled {
|
||||
logger.Debugf("支付宝,%+v", data)
|
||||
data.Channel = payment.PayChannelAL
|
||||
if h.config.Alipay.Domain != "" { // 用于本地调试支付
|
||||
data.Domain = h.config.Alipay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/alipay", data.Domain)
|
||||
money := fmt.Sprintf("%.2f", amount)
|
||||
payURL, err = h.alipayService.Pay(payment.PayRequest{
|
||||
Device: data.Device,
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: money,
|
||||
NotifyURL: notifyURL,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
break
|
||||
case "hupi":
|
||||
if h.App.Config.HuPiPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/hupi", data.Host)
|
||||
}
|
||||
if h.App.Config.HuPiPayConfig.ReturnURL != "" {
|
||||
returnURL = h.App.Config.HuPiPayConfig.ReturnURL
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
r, err := h.huPiPayService.Pay(payment.HuPiPayParams{
|
||||
Version: "1.1",
|
||||
TradeOrderId: orderNo,
|
||||
TotalFee: fmt.Sprintf("%f", amount),
|
||||
Title: product.Name,
|
||||
NotifyURL: notifyURL,
|
||||
ReturnURL: returnURL,
|
||||
WapName: "GeekAI助手",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
payURL = r.URL
|
||||
break
|
||||
case "geek":
|
||||
if h.App.Config.GeekPayConfig.NotifyURL != "" {
|
||||
notifyURL = h.App.Config.GeekPayConfig.NotifyURL
|
||||
} else {
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/geek", data.Host)
|
||||
}
|
||||
if h.App.Config.GeekPayConfig.ReturnURL != "" {
|
||||
data.Host = utils.GetBaseURL(h.App.Config.GeekPayConfig.ReturnURL)
|
||||
}
|
||||
if data.Device == "wechat" { // 微信客户端打开,调回手机端用户中心页面
|
||||
returnURL = fmt.Sprintf("%s/mobile/profile", data.Host)
|
||||
} else {
|
||||
returnURL = fmt.Sprintf("%s/payReturn", data.Host)
|
||||
}
|
||||
params := payment.GeekPayParams{
|
||||
OutTradeNo: orderNo,
|
||||
Method: "web",
|
||||
Name: product.Name,
|
||||
Money: fmt.Sprintf("%f", amount),
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
Type: data.PayType,
|
||||
ReturnURL: returnURL,
|
||||
NotifyURL: notifyURL,
|
||||
}
|
||||
|
||||
res, err := h.geekPayService.Pay(params)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate pay url: "+err.Error())
|
||||
return
|
||||
}
|
||||
} else if h.config.Epay.Enabled { // 聚合支付
|
||||
logger.Debugf("聚合支付,%+v", data)
|
||||
data.Channel = payment.PayChannelEpay
|
||||
if h.config.Epay.Domain != "" {
|
||||
data.Domain = h.config.Epay.Domain
|
||||
}
|
||||
notifyURL = fmt.Sprintf("%s/api/payment/notify/epay", data.Domain)
|
||||
params := payment.PayRequest{
|
||||
OutTradeNo: orderNo,
|
||||
Subject: product.Name,
|
||||
TotalFee: fmt.Sprintf("%f", amount),
|
||||
ClientIP: c.ClientIP(),
|
||||
Device: data.Device,
|
||||
PayWay: data.PayWay,
|
||||
NotifyURL: notifyURL,
|
||||
}
|
||||
|
||||
r, err := h.epayService.Pay(params)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
} else {
|
||||
payURL = r
|
||||
}
|
||||
} else {
|
||||
resp.ERROR(c, "系统没有配置可用的支付渠道!")
|
||||
return
|
||||
}
|
||||
payURL = res.PayURL
|
||||
default:
|
||||
resp.ERROR(c, "不支持的支付渠道")
|
||||
return
|
||||
@@ -234,43 +313,40 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||
|
||||
// 创建订单
|
||||
remark := types.OrderRemark{
|
||||
Days: product.Days,
|
||||
Power: product.Power,
|
||||
Name: product.Name,
|
||||
Price: product.Price,
|
||||
Discount: product.Discount,
|
||||
Power: product.Power,
|
||||
Name: product.Name,
|
||||
Price: product.Price,
|
||||
}
|
||||
order := model.Order{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
ProductId: product.Id,
|
||||
OrderNo: orderNo,
|
||||
Subject: product.Name,
|
||||
Amount: amount,
|
||||
Status: types.OrderNotPaid,
|
||||
PayWay: data.PayWay,
|
||||
PayType: data.PayType,
|
||||
Remark: utils.JsonEncode(remark),
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
OrderNo: orderNo,
|
||||
Subject: product.Name,
|
||||
Amount: amount,
|
||||
Status: types.OrderNotPaid,
|
||||
PayWay: data.PayWay,
|
||||
Channel: data.Channel,
|
||||
Remark: utils.JsonEncode(remark),
|
||||
}
|
||||
err = h.DB.Create(&order).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with create order: "+err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, payURL)
|
||||
resp.SUCCESS(c, gin.H{"pay_url": payURL, "order_no": orderNo})
|
||||
}
|
||||
|
||||
// 异步通知回调公共逻辑
|
||||
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
// 支付成功处理
|
||||
func (h *PaymentHandler) paySuccess(info payment.OrderInfo) error {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
var order model.Order
|
||||
err := h.DB.Where("order_no = ?", orderNo).First(&order).Error
|
||||
err := h.DB.Where("order_no", info.OutTradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with fetch order: %v", err)
|
||||
}
|
||||
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
// 已支付订单,直接返回
|
||||
if order.Status == types.OrderPaidSuccess {
|
||||
return nil
|
||||
@@ -290,19 +366,21 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
|
||||
// 增加用户算力
|
||||
err = h.userService.IncreasePower(order.UserId, remark.Power, model.PowerLog{
|
||||
Type: types.PowerRecharge,
|
||||
Model: order.PayWay,
|
||||
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||
Type: types.PowerRecharge,
|
||||
Model: order.Subject,
|
||||
Remark: fmt.Sprintf("充值算力,金额:%f,订单号:%s", order.Amount, order.OrderNo),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新订单状态
|
||||
order.PayTime = time.Now().Unix()
|
||||
order.PayTime = utils.Str2stamp(info.PayTime)
|
||||
order.Status = types.OrderPaidSuccess
|
||||
order.TradeNo = tradeNo
|
||||
err = h.DB.Updates(&order).Error
|
||||
order.TradeNo = info.TradeId
|
||||
order.Checked = true
|
||||
err = h.DB.Debug().Updates(&order).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with update order info: %v", err)
|
||||
}
|
||||
@@ -317,54 +395,6 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPayWays 获取支付方式
|
||||
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
|
||||
payWays := make([]gin.H, 0)
|
||||
if h.App.Config.AlipayConfig.Enabled {
|
||||
payWays = append(payWays, gin.H{"pay_way": "alipay", "pay_type": "alipay"})
|
||||
}
|
||||
if h.App.Config.HuPiPayConfig.Enabled {
|
||||
payWays = append(payWays, gin.H{"pay_way": "hupi", "pay_type": "wxpay"})
|
||||
}
|
||||
if h.App.Config.GeekPayConfig.Enabled {
|
||||
for _, v := range h.App.Config.GeekPayConfig.Methods {
|
||||
payWays = append(payWays, gin.H{"pay_way": "geek", "pay_type": v})
|
||||
}
|
||||
}
|
||||
if h.App.Config.WechatPayConfig.Enabled {
|
||||
payWays = append(payWays, gin.H{"pay_way": "wechat", "pay_type": "wxpay"})
|
||||
}
|
||||
resp.SUCCESS(c, payWays)
|
||||
}
|
||||
|
||||
// HuPiPayNotify 虎皮椒支付异步回调
|
||||
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
orderNo := c.Request.Form.Get("trade_order_id")
|
||||
tradeNo := c.Request.Form.Get("open_order_id")
|
||||
logger.Infof("收到虎皮椒订单支付回调,%+v", c.Request.Form)
|
||||
|
||||
if err = h.huPiPayService.Check(orderNo); err != nil {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.notify(orderNo, tradeNo)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// AlipayNotify 支付宝支付回调
|
||||
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
@@ -373,16 +403,15 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result := h.alipayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到支付宝商号订单支付回调:%+v", result)
|
||||
if !result.Success() {
|
||||
logger.Error("订单校验失败:", result.Message)
|
||||
orderInfo, err := h.alipayService.Query(c.Request.Form.Get("out_trade_no"))
|
||||
logger.Infof("收到支付宝商号订单支付回调:%+v", orderInfo)
|
||||
if !orderInfo.Success() {
|
||||
logger.Errorf("订单校验失败:%v", err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
tradeNo := c.Request.Form.Get("trade_no")
|
||||
err = h.notify(result.OutTradeNo, tradeNo)
|
||||
err = h.paySuccess(orderInfo)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
@@ -392,28 +421,35 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// GeekPayNotify 支付异步回调
|
||||
func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
|
||||
// EPayNotify 易支付支付异步回调
|
||||
func (h *PaymentHandler) EPayNotify(c *gin.Context) {
|
||||
var params = make(map[string]string)
|
||||
for k := range c.Request.URL.Query() {
|
||||
params[k] = c.Query(k)
|
||||
}
|
||||
|
||||
logger.Infof("收到GeekPay订单支付回调:%+v", params)
|
||||
// 检查支付状态
|
||||
logger.Infof("收到易支付订单支付回调:%+v", params)
|
||||
// 检查支付状态, 如果未支付,则返回成功
|
||||
if params["trade_status"] != "TRADE_SUCCESS" {
|
||||
c.String(http.StatusOK, "success")
|
||||
return
|
||||
}
|
||||
|
||||
sign := h.geekPayService.Sign(params)
|
||||
sign := h.epayService.Sign(params)
|
||||
if sign != c.Query("sign") {
|
||||
logger.Errorf("签名验证失败, %s, %s", sign, c.Query("sign"))
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
// 查询订单状态
|
||||
order, err := h.epayService.Query(params["out_trade_no"])
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.notify(params["out_trade_no"], params["trade_no"])
|
||||
err = h.paySuccess(order)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
@@ -423,26 +459,23 @@ func (h *PaymentHandler) GeekPayNotify(c *gin.Context) {
|
||||
c.String(http.StatusOK, "success")
|
||||
}
|
||||
|
||||
// WechatPayNotify 微信商户支付异步回调
|
||||
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
|
||||
// WxpayNotify 微信商户支付异步回调
|
||||
func (h *PaymentHandler) WxpayNotify(c *gin.Context) {
|
||||
err := c.Request.ParseForm()
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
result := h.wechatPayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到微信商号订单支付回调:%+v", result)
|
||||
if !result.Success() {
|
||||
logger.Error("订单校验失败:", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": "FAIL",
|
||||
"message": err.Error(),
|
||||
})
|
||||
orderInfo, err := h.wxpayService.TradeVerify(c.Request)
|
||||
logger.Infof("收到微信商号订单支付回调:%+v", orderInfo)
|
||||
if err != nil {
|
||||
logger.Errorf("订单校验失败:%v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": "FAIL"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.notify(result.OutTradeNo, result.TradeId)
|
||||
err = h.paySuccess(orderInfo)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.String(http.StatusOK, "fail")
|
||||
|
||||
@@ -9,11 +9,13 @@ package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -27,6 +29,18 @@ func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
|
||||
return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PowerLogHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/powerLog/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("list", h.List)
|
||||
group.GET("stats", h.Stats)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PowerLogHandler) List(c *gin.Context) {
|
||||
var data struct {
|
||||
Model string `json:"model"`
|
||||
@@ -72,3 +86,45 @@ func (h *PowerLogHandler) List(c *gin.Context) {
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||||
}
|
||||
|
||||
// Stats 获取用户算力统计
|
||||
func (h *PowerLogHandler) Stats(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
if userId == 0 {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息(包含余额)
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 计算总消费(所有支出记录)
|
||||
var totalConsume int64
|
||||
h.DB.Model(&model.PowerLog{}).
|
||||
Where("user_id", userId).
|
||||
Where("mark", types.PowerSub).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&totalConsume)
|
||||
|
||||
// 计算今日消费
|
||||
today := time.Now().Format("2006-01-02")
|
||||
var todayConsume int64
|
||||
h.DB.Model(&model.PowerLog{}).
|
||||
Where("user_id", userId).
|
||||
Where("mark", types.PowerSub).
|
||||
Where("DATE(created_at) = ?", today).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&todayConsume)
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total": totalConsume,
|
||||
"today": todayConsume,
|
||||
"balance": user.Power,
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, stats)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -25,6 +26,12 @@ func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
|
||||
return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *ProductHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/product/")
|
||||
group.GET("list", h.List)
|
||||
}
|
||||
|
||||
// List 模型列表
|
||||
func (h *ProductHandler) List(c *gin.Context) {
|
||||
var items []model.Product
|
||||
|
||||
@@ -10,12 +10,14 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@@ -39,6 +41,20 @@ func NewPromptHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *PromptHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/prompt/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis)).Use(middleware.RateLimitEvery(h.App.Redis, 30*time.Second))
|
||||
{
|
||||
group.POST("lyric", h.Lyric)
|
||||
group.POST("image", h.Image)
|
||||
group.POST("video", h.Video)
|
||||
group.POST("meta", h.MetaPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Lyric 生成歌词
|
||||
func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
var data struct {
|
||||
@@ -48,25 +64,12 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.PromptPower > 0 {
|
||||
userId := h.GetLoginUserId(c)
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: h.getPromptModel(),
|
||||
Remark: "生成歌词",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
@@ -79,23 +82,12 @@ func (h *PromptHandler) Image(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if h.App.SysConfig.PromptPower > 0 {
|
||||
userId := h.GetLoginUserId(c)
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: h.getPromptModel(),
|
||||
Remark: "生成绘画提示词",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||
}
|
||||
|
||||
@@ -108,25 +100,12 @@ func (h *PromptHandler) Video(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.Base.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.PromptPower > 0 {
|
||||
userId := h.GetLoginUserId(c)
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.PromptPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: h.getPromptModel(),
|
||||
Remark: "生成视频脚本",
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, strings.Trim(content, `"`))
|
||||
}
|
||||
|
||||
@@ -158,9 +137,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *PromptHandler) getPromptModel() string {
|
||||
if h.App.SysConfig.AssistantModelId > 0 {
|
||||
if h.App.SysConfig.Base.AssistantModelId > 0 {
|
||||
var chatModel model.ChatModel
|
||||
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
|
||||
h.DB.Where("id", h.App.SysConfig.Base.AssistantModelId).First(&chatModel)
|
||||
return chatModel.Value
|
||||
}
|
||||
return "gpt-4o"
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
@@ -39,6 +40,18 @@ func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *servic
|
||||
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *RealtimeHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/realtime/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.Any("", h.Connection)
|
||||
group.POST("voice", h.VoiceChat)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||
// 获取客户端请求中指定的子协议
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
@@ -154,7 +167,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.AdvanceVoicePower {
|
||||
if user.Power < h.App.SysConfig.Base.AdvanceVoicePower {
|
||||
resp.ERROR(c, "当前用户算力不足,无法使用该功能")
|
||||
return
|
||||
}
|
||||
@@ -198,7 +211,7 @@ func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
|
||||
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 扣减算力
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
|
||||
err = h.userService.DecreasePower(userId, h.App.SysConfig.Base.AdvanceVoicePower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "advanced-voice",
|
||||
Remark: "实时语音通话",
|
||||
|
||||
@@ -10,14 +10,16 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store/model"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RedeemHandler struct {
|
||||
@@ -30,6 +32,17 @@ func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.Use
|
||||
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *RedeemHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/redeem/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("verify", h.Verify)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RedeemHandler) Verify(c *gin.Context) {
|
||||
var data struct {
|
||||
Code string `json:"code"`
|
||||
|
||||
@@ -10,8 +10,10 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
@@ -28,12 +30,13 @@ import (
|
||||
|
||||
type SdJobHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
sdService *sd.Service
|
||||
uploader *oss.UploaderManager
|
||||
snowflake *service.Snowflake
|
||||
leveldb *store.LevelDB
|
||||
userService *service.UserService
|
||||
redis *redis.Client
|
||||
sdService *sd.Service
|
||||
uploader *oss.UploaderManager
|
||||
snowflake *service.Snowflake
|
||||
leveldb *store.LevelDB
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewSdJobHandler(app *core.AppServer,
|
||||
@@ -42,13 +45,15 @@ func NewSdJobHandler(app *core.AppServer,
|
||||
manager *oss.UploaderManager,
|
||||
snowflake *service.Snowflake,
|
||||
userService *service.UserService,
|
||||
levelDB *store.LevelDB) *SdJobHandler {
|
||||
levelDB *store.LevelDB,
|
||||
moderationManager *moderation.ServiceManager) *SdJobHandler {
|
||||
return &SdJobHandler{
|
||||
sdService: service,
|
||||
uploader: manager,
|
||||
snowflake: snowflake,
|
||||
leveldb: levelDB,
|
||||
userService: userService,
|
||||
sdService: service,
|
||||
uploader: manager,
|
||||
snowflake: snowflake,
|
||||
leveldb: levelDB,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
@@ -56,6 +61,23 @@ func NewSdJobHandler(app *core.AppServer,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *SdJobHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/sd/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
@@ -63,7 +85,7 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.SdPower {
|
||||
if user.Power < h.App.SysConfig.Base.SdPower {
|
||||
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
|
||||
return false
|
||||
}
|
||||
@@ -84,6 +106,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceSD,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if data.Width <= 0 {
|
||||
data.Width = 512
|
||||
}
|
||||
@@ -131,7 +176,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
HdSteps: data.HdSteps,
|
||||
},
|
||||
UserId: userId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
}
|
||||
|
||||
job := model.SdJob{
|
||||
@@ -142,7 +187,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
Prompt: data.Prompt,
|
||||
Progress: 0,
|
||||
Power: h.App.SysConfig.SdPower,
|
||||
Power: h.App.SysConfig.Base.SdPower,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
res := h.DB.Create(&job)
|
||||
|
||||
@@ -24,24 +24,31 @@ const CodeStorePrefix = "/verify/codes/"
|
||||
|
||||
type SmsHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
sms *sms.ServiceManager
|
||||
smtp *service.SmtpService
|
||||
captcha *service.CaptchaService
|
||||
redis *redis.Client
|
||||
sms *sms.SmsManager
|
||||
smtp *service.SmtpService
|
||||
captchaService *service.CaptchaService
|
||||
}
|
||||
|
||||
func NewSmsHandler(
|
||||
app *core.AppServer,
|
||||
client *redis.Client,
|
||||
sms *sms.ServiceManager,
|
||||
sms *sms.SmsManager,
|
||||
smtp *service.SmtpService,
|
||||
captcha *service.CaptchaService) *SmsHandler {
|
||||
return &SmsHandler{
|
||||
redis: client,
|
||||
sms: sms,
|
||||
captcha: captcha,
|
||||
smtp: smtp,
|
||||
BaseHandler: BaseHandler{App: app}}
|
||||
redis: client,
|
||||
sms: sms,
|
||||
captchaService: captcha,
|
||||
smtp: smtp,
|
||||
BaseHandler: BaseHandler{App: app}}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *SmsHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/sms/")
|
||||
// 无需授权的接口
|
||||
group.POST("code", h.SendCode)
|
||||
}
|
||||
|
||||
// SendCode 发送验证码
|
||||
@@ -56,12 +63,12 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if h.App.SysConfig.EnabledVerify {
|
||||
if h.captchaService.GetConfig().Enabled {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
check = h.captchaService.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captcha.Check(data)
|
||||
check = h.captchaService.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
@@ -72,14 +79,14 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
code := utils.RandomNumber(6)
|
||||
var err error
|
||||
if strings.Contains(data.Receiver, "@") { // email
|
||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
|
||||
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "email") {
|
||||
resp.ERROR(c, "系统已禁用邮箱注册!")
|
||||
return
|
||||
}
|
||||
// 检查邮箱后缀是否在白名单
|
||||
if len(h.App.SysConfig.EmailWhiteList) > 0 {
|
||||
if len(h.App.SysConfig.Base.EmailWhiteList) > 0 {
|
||||
inWhiteList := false
|
||||
for _, suffix := range h.App.SysConfig.EmailWhiteList {
|
||||
for _, suffix := range h.App.SysConfig.Base.EmailWhiteList {
|
||||
if strings.HasSuffix(data.Receiver, suffix) {
|
||||
inWhiteList = true
|
||||
break
|
||||
@@ -92,7 +99,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
||||
}
|
||||
err = h.smtp.SendVerifyCode(data.Receiver, code)
|
||||
} else {
|
||||
if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
|
||||
if !utils.Contains(h.App.SysConfig.Base.RegisterWays, "mobile") {
|
||||
resp.ERROR(c, "系统已禁用手机号注册!")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,8 +10,10 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/suno"
|
||||
"geekai/store/model"
|
||||
@@ -26,20 +28,41 @@ import (
|
||||
|
||||
type SunoHandler struct {
|
||||
BaseHandler
|
||||
sunoService *suno.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
sunoService *suno.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
|
||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *SunoHandler {
|
||||
return &SunoHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
},
|
||||
sunoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
sunoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *SunoHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/suno/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.GET("play", h.Play)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("create", h.Create)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("detail", h.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,13 +87,36 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceSuno,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.SunoPower {
|
||||
if user.Power < h.App.SysConfig.Base.SunoPower {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
@@ -118,7 +164,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
RefSongId: data.RefSongId,
|
||||
RefTaskId: data.RefTaskId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Power: h.App.SysConfig.SunoPower,
|
||||
Power: h.App.SysConfig.Base.SunoPower,
|
||||
SongId: utils.RandString(32),
|
||||
}
|
||||
if data.Lyrics != "" {
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/service"
|
||||
"geekai/service/payment"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type TestHandler struct {
|
||||
App *core.AppServer
|
||||
db *gorm.DB
|
||||
snowflake *service.Snowflake
|
||||
js *payment.GeekPayService
|
||||
js *payment.EPayService
|
||||
}
|
||||
|
||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekPayService) *TestHandler {
|
||||
return &TestHandler{db: db, snowflake: snowflake, js: js}
|
||||
func NewTestHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, js *payment.EPayService) *TestHandler {
|
||||
return &TestHandler{App: app, db: db, snowflake: snowflake, js: js}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *TestHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/test/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.Any("sse", h.PostTest, h.SseTest)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TestHandler) SseTest(c *gin.Context) {
|
||||
|
||||
@@ -8,8 +8,10 @@ package handler
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/store"
|
||||
@@ -20,8 +22,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
@@ -36,8 +36,10 @@ type UserHandler struct {
|
||||
redis *redis.Client
|
||||
levelDB *store.LevelDB
|
||||
licenseService *service.LicenseService
|
||||
captcha *service.CaptchaService
|
||||
captchaService *service.CaptchaService
|
||||
userService *service.UserService
|
||||
wxLoginService *service.WxLoginService
|
||||
ipSearcher *xdb.Searcher
|
||||
}
|
||||
|
||||
func NewUserHandler(
|
||||
@@ -48,15 +50,45 @@ func NewUserHandler(
|
||||
levelDB *store.LevelDB,
|
||||
captcha *service.CaptchaService,
|
||||
userService *service.UserService,
|
||||
wxLoginService *service.WxLoginService,
|
||||
ipSearcher *xdb.Searcher,
|
||||
licenseService *service.LicenseService) *UserHandler {
|
||||
return &UserHandler{
|
||||
BaseHandler: BaseHandler{DB: db, App: app},
|
||||
searcher: searcher,
|
||||
redis: client,
|
||||
levelDB: levelDB,
|
||||
captcha: captcha,
|
||||
captchaService: captcha,
|
||||
licenseService: licenseService,
|
||||
userService: userService,
|
||||
wxLoginService: wxLoginService,
|
||||
ipSearcher: ipSearcher,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *UserHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/user/")
|
||||
|
||||
// 公开接口,不需要授权
|
||||
group.POST("register", h.Register)
|
||||
group.POST("login", h.Login)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
group.GET("login/qrcode", h.GetWxLoginQRCode)
|
||||
group.POST("login/callback", h.WxLoginCallback)
|
||||
group.GET("login/status", h.GetWxLoginState)
|
||||
group.GET("logout", h.Logout)
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.GET("session", h.Session)
|
||||
group.GET("profile", h.Profile)
|
||||
group.POST("profile/update", h.ProfileUpdate)
|
||||
group.POST("password", h.UpdatePass)
|
||||
group.POST("bind/mobile", h.BindMobile)
|
||||
group.POST("bind/email", h.BindEmail)
|
||||
group.GET("signin", h.SignIn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,12 +112,13 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.EnabledVerify && data.RegWay == "username" {
|
||||
// 人机验证
|
||||
if h.captchaService.GetConfig().Enabled {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
check = h.captchaService.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captcha.Check(data)
|
||||
check = h.captchaService.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
@@ -125,30 +158,8 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证邀请码
|
||||
inviteCode := model.InviteCode{}
|
||||
if data.InviteCode != "" {
|
||||
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "无效的邀请码")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
salt := utils.RandString(8)
|
||||
user := model.User{
|
||||
Username: data.Username,
|
||||
Password: utils.GenPassword(data.Password, salt),
|
||||
Avatar: "/images/avatar/user.png",
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||
ChatConfig: "{}",
|
||||
ChatModels: "{}",
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
}
|
||||
|
||||
// check if the username is existing
|
||||
user := model.User{Username: data.Username, Password: data.Password}
|
||||
var item model.User
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Mobile != "" {
|
||||
@@ -168,78 +179,19 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 被邀请人也获得赠送算力
|
||||
if data.InviteCode != "" {
|
||||
user.Power += h.App.SysConfig.InvitePower
|
||||
}
|
||||
|
||||
if h.licenseService.GetLicense().Configs.DeCopy {
|
||||
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||
} else {
|
||||
defaultNickname := h.App.SysConfig.DefaultNickname
|
||||
if defaultNickname == "" {
|
||||
defaultNickname = "极客学长"
|
||||
}
|
||||
user.Nickname = fmt.Sprintf("%s@%d", defaultNickname, utils.RandomNumber(6))
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
user, err := h.createNewUser(user, data.InviteCode)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录邀请关系
|
||||
if data.InviteCode != "" {
|
||||
// 增加邀请数量
|
||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||
if h.App.SysConfig.InvitePower > 0 {
|
||||
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.InvitePower, model.PowerLog{
|
||||
Type: types.PowerInvite,
|
||||
Model: "Invite",
|
||||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 添加邀请记录
|
||||
err := tx.Create(&model.InviteLog{
|
||||
InviterId: inviteCode.UserId,
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
InviteCode: inviteCode.Code,
|
||||
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
||||
}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
|
||||
// 自动登录创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
token, err := h.doLogin(&user, c.ClientIP())
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 保存到 redis
|
||||
key = fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||
|
||||
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
@@ -255,15 +207,12 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
verifyKey := fmt.Sprintf("users/verify/%s", data.Username)
|
||||
needVerify, err := h.redis.Get(c, verifyKey).Bool()
|
||||
|
||||
if h.App.SysConfig.EnabledVerify && needVerify {
|
||||
if h.captchaService.GetConfig().Enabled {
|
||||
var check bool
|
||||
if data.X != 0 {
|
||||
check = h.captcha.SlideCheck(data)
|
||||
check = h.captchaService.SlideCheck(data)
|
||||
} else {
|
||||
check = h.captcha.Check(data)
|
||||
check = h.captchaService.Check(data)
|
||||
}
|
||||
if !check {
|
||||
resp.ERROR(c, "请先完人机验证")
|
||||
@@ -274,54 +223,28 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
var user model.User
|
||||
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||||
if res.Error != nil {
|
||||
h.redis.Set(c, verifyKey, true, 0)
|
||||
resp.ERROR(c, "用户名不存在")
|
||||
return
|
||||
}
|
||||
|
||||
password := utils.GenPassword(data.Password, user.Salt)
|
||||
if password != user.Password {
|
||||
h.redis.Set(c, verifyKey, true, 0)
|
||||
resp.ERROR(c, "用户名或密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Status == false {
|
||||
if !user.Status {
|
||||
resp.ERROR(c, "该用户已被禁止登录,请联系管理员")
|
||||
return
|
||||
}
|
||||
|
||||
// 更新最后登录时间和IP
|
||||
user.LastLoginIp = c.ClientIP()
|
||||
user.LastLoginAt = time.Now().Unix()
|
||||
h.DB.Model(&user).Updates(user)
|
||||
|
||||
h.DB.Create(&model.UserLoginLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
LoginIp: c.ClientIP(),
|
||||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
||||
})
|
||||
|
||||
// 创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
token, err := h.doLogin(&user, c.ClientIP())
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 保存到 redis
|
||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err = h.redis.Set(c, sessionKey, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
return
|
||||
}
|
||||
// 移除登录行为验证码
|
||||
h.redis.Del(c, verifyKey)
|
||||
resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
|
||||
|
||||
resp.SUCCESS(c, gin.H{"token": token, "user_id": user.Id, "username": user.Username})
|
||||
}
|
||||
|
||||
// Logout 注 销
|
||||
@@ -333,134 +256,165 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// CLogin 第三方登录请求二维码
|
||||
func (h *UserHandler) CLogin(c *gin.Context) {
|
||||
returnURL := h.GetTrim(c, "return_url")
|
||||
var res types.BizVo
|
||||
apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
|
||||
r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
|
||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
// GetWxLoginQRCode 获取微信登录二维码URL
|
||||
func (h *UserHandler) GetWxLoginQRCode(c *gin.Context) {
|
||||
if !h.wxLoginService.GetConfig().Enabled {
|
||||
resp.ERROR(c, "微信登录功能未启用")
|
||||
return
|
||||
}
|
||||
|
||||
if h.wxLoginService.GetConfig().ApiKey == "" {
|
||||
resp.ERROR(c, "微信登录服务令牌未配置")
|
||||
return
|
||||
}
|
||||
|
||||
state := utils.RandString(32)
|
||||
qrCodeURL, err := h.wxLoginService.GetLoginQrCodeUrl(state)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
||||
return
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
resp.ERROR(c, "error with http response: "+res.Message)
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, res.Data)
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"url": qrCodeURL,
|
||||
"state": state,
|
||||
})
|
||||
}
|
||||
|
||||
// CLoginCallback 第三方登录回调
|
||||
func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
||||
loginType := c.Query("login_type")
|
||||
code := c.Query("code")
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
action := c.Query("action")
|
||||
// 查询微信登录状态
|
||||
func (h *UserHandler) GetWxLoginState(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if state == "" {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var res types.BizVo
|
||||
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
|
||||
r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
|
||||
SetHeader("AppId", h.App.Config.ApiConfig.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
status, err := h.wxLoginService.GetLoginStatus(state)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
if r.IsErrorState() {
|
||||
resp.ERROR(c, "error with login http status: "+r.Status)
|
||||
|
||||
if status.Status != service.LoginStatusSuccess {
|
||||
resp.SUCCESS(c, status)
|
||||
return
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
resp.ERROR(c, "error with http response: "+res.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// login successfully
|
||||
data := res.Data.(map[string]interface{})
|
||||
// 登录成功
|
||||
var user model.User
|
||||
if action == "bind" && userId > 0 {
|
||||
err = h.DB.Where("openid", data["openid"]).First(&user).Error
|
||||
if err == nil {
|
||||
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.DB.Where("id", userId).First(&user).Error
|
||||
h.DB.Where("openid = ?", status.OpenID).First(&user)
|
||||
if user.Id == 0 {
|
||||
// 创建新用户
|
||||
user, err = h.createNewUser(model.User{OpenId: status.OpenID}, "")
|
||||
if err != nil {
|
||||
resp.ERROR(c, "绑定用户不存在")
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新用户信息失败,"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"token": ""})
|
||||
token, err := h.doLogin(&user, c.ClientIP())
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
session := gin.H{}
|
||||
tx := h.DB.Where("openid", data["openid"]).First(&user)
|
||||
if tx.Error != nil {
|
||||
// create new user
|
||||
var totalUser int64
|
||||
h.DB.Model(&model.User{}).Count(&totalUser)
|
||||
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
|
||||
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||||
return
|
||||
}
|
||||
status.Status = service.LoginStatusExpired
|
||||
h.wxLoginService.SetLoginStatus(state, *status)
|
||||
|
||||
salt := utils.RandString(8)
|
||||
password := fmt.Sprintf("%d", utils.RandomNumber(8))
|
||||
user = model.User{
|
||||
Username: fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
|
||||
Password: utils.GenPassword(password, salt),
|
||||
Avatar: fmt.Sprintf("%s", data["avatar"]),
|
||||
Salt: salt,
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||||
Power: h.App.SysConfig.InitPower,
|
||||
OpenId: fmt.Sprintf("%s", data["openid"]),
|
||||
Nickname: fmt.Sprintf("%s", data["nickname"]),
|
||||
}
|
||||
status.Status = service.LoginStatusSuccess
|
||||
status.Token = token
|
||||
resp.SUCCESS(c, status)
|
||||
}
|
||||
|
||||
tx = h.DB.Create(&user)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, "保存数据失败")
|
||||
logger.Error(tx.Error)
|
||||
return
|
||||
// createNewUser 创建新用户
|
||||
func (h *UserHandler) createNewUser(user model.User, inviteCode string) (model.User, error) {
|
||||
if user.OpenId != "" {
|
||||
user.Platform = "wechat"
|
||||
user.Nickname = fmt.Sprintf("微信用户@%d", utils.RandomNumber(6))
|
||||
user.Username = fmt.Sprintf("wx@%d", utils.RandomNumber(8))
|
||||
user.Password = "geekai123"
|
||||
} else {
|
||||
user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
|
||||
if user.Username == "" || user.Password == "" {
|
||||
return user, fmt.Errorf("用户名或密码不能为空")
|
||||
}
|
||||
session["username"] = user.Username
|
||||
session["password"] = password
|
||||
} else { // login directly
|
||||
// 更新最后登录时间和IP
|
||||
user.LastLoginIp = c.ClientIP()
|
||||
user.LastLoginAt = time.Now().Unix()
|
||||
h.DB.Model(&user).Updates(user)
|
||||
|
||||
h.DB.Create(&model.UserLoginLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
LoginIp: c.ClientIP(),
|
||||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
||||
})
|
||||
}
|
||||
|
||||
salt := utils.RandString(8)
|
||||
user.Salt = salt
|
||||
user.Password = utils.GenPassword(user.Password, salt)
|
||||
user.Avatar = "/images/avatar/user.png"
|
||||
user.Status = true
|
||||
user.ChatRoles = utils.JsonEncode([]string{"gpt"})
|
||||
user.ChatConfig = "{}"
|
||||
user.ChatModels = "{}"
|
||||
user.Power = h.App.SysConfig.Base.InitPower
|
||||
|
||||
// 创建用户
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
// 记录邀请关系
|
||||
if inviteCode != "" {
|
||||
inviteCode := model.InviteCode{}
|
||||
err := h.DB.Where("code = ?", inviteCode).First(&inviteCode).Error
|
||||
if err != nil {
|
||||
return user, fmt.Errorf("无效的邀请码")
|
||||
}
|
||||
|
||||
// 增加邀请数量
|
||||
h.DB.Model(&model.InviteCode{}).Where("code = ?", inviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||||
if h.App.SysConfig.Base.InvitePower > 0 {
|
||||
err := h.userService.IncreasePower(inviteCode.UserId, h.App.SysConfig.Base.InvitePower, model.PowerLog{
|
||||
Type: types.PowerInvite,
|
||||
Model: "Invite",
|
||||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.Base.InvitePower, inviteCode.Code, user.Username),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return user, err
|
||||
}
|
||||
|
||||
// 添加邀请记录
|
||||
err = tx.Create(&model.InviteLog{
|
||||
InviterId: inviteCode.UserId,
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
InviteCode: inviteCode.Code,
|
||||
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.Base.InvitePower),
|
||||
}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return user, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// doLogin 执行登录操作
|
||||
func (h *UserHandler) doLogin(user *model.User, ip string) (string, error) {
|
||||
// 更新最后登录时间和IP
|
||||
user.LastLoginIp = ip
|
||||
user.LastLoginAt = time.Now().Unix()
|
||||
err := h.DB.Model(user).Updates(user).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to update user: %v", err)
|
||||
}
|
||||
|
||||
// 记录登录日志
|
||||
h.DB.Create(&model.UserLoginLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
LoginIp: ip,
|
||||
LoginAddress: utils.Ip2Region(h.ipSearcher, ip),
|
||||
})
|
||||
|
||||
// 创建 token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.Id,
|
||||
@@ -468,17 +422,42 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||||
return
|
||||
return "", fmt.Errorf("failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
// 保存到 redis
|
||||
key := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||||
resp.ERROR(c, "error with save token: "+err.Error())
|
||||
sessionKey := fmt.Sprintf("users/%d", user.Id)
|
||||
if _, err = h.redis.Set(context.Background(), sessionKey, tokenString, 0).Result(); err != nil {
|
||||
return "", fmt.Errorf("error with save token: %v", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// WxLoginCallback 微信登录回调处理
|
||||
func (h *UserHandler) WxLoginCallback(c *gin.Context) {
|
||||
var data struct {
|
||||
OpenID string `json:"openid"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
session["token"] = tokenString
|
||||
resp.SUCCESS(c, session)
|
||||
|
||||
if data.OpenID == "" || data.State == "" {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置登录状态
|
||||
status := service.LoginStatus{
|
||||
Status: service.LoginStatusSuccess,
|
||||
OpenID: data.OpenID,
|
||||
}
|
||||
h.wxLoginService.SetLoginStatus(data.State, status)
|
||||
|
||||
resp.SUCCESS(c, status)
|
||||
}
|
||||
|
||||
// Session 获取/验证会话
|
||||
@@ -742,11 +721,11 @@ func (h *UserHandler) SignIn(c *gin.Context) {
|
||||
|
||||
// 签到
|
||||
h.levelDB.Put(key, true)
|
||||
if h.App.SysConfig.DailyPower > 0 {
|
||||
h.userService.IncreasePower(userId, h.App.SysConfig.DailyPower, model.PowerLog{
|
||||
if h.App.SysConfig.Base.DailyPower > 0 {
|
||||
h.userService.IncreasePower(userId, h.App.SysConfig.Base.DailyPower, model.PowerLog{
|
||||
Type: types.PowerSignIn,
|
||||
Model: "SignIn",
|
||||
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.DailyPower),
|
||||
Remark: fmt.Sprintf("每日签到奖励,金额:%d", h.App.SysConfig.Base.DailyPower),
|
||||
})
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
|
||||
@@ -10,8 +10,10 @@ package handler
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/middleware"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/video"
|
||||
"geekai/store/model"
|
||||
@@ -26,20 +28,37 @@ import (
|
||||
|
||||
type VideoHandler struct {
|
||||
BaseHandler
|
||||
videoService *video.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
videoService *video.Service
|
||||
uploader *oss.UploaderManager
|
||||
userService *service.UserService
|
||||
moderationManager *moderation.ServiceManager
|
||||
}
|
||||
|
||||
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
|
||||
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService, moderationManager *moderation.ServiceManager) *VideoHandler {
|
||||
return &VideoHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
},
|
||||
videoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
videoService: service,
|
||||
uploader: uploader,
|
||||
userService: userService,
|
||||
moderationManager: moderationManager,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *VideoHandler) RegisterRoutes() {
|
||||
group := h.App.Engine.Group("/api/video/")
|
||||
|
||||
// 需要用户授权的接口
|
||||
group.Use(middleware.UserAuthMiddleware(h.App.Config.Session.SecretKey, h.App.Redis))
|
||||
{
|
||||
group.POST("luma/create", h.LumaCreate)
|
||||
group.POST("keling/create", h.KeLingCreate)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +81,36 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.Moderation.Enable {
|
||||
moderationResult, err := h.moderationManager.GetService().Moderate(data.Prompt)
|
||||
if err != nil {
|
||||
logger.Error("failed to moderate content: ", err)
|
||||
}
|
||||
if moderationResult.Flagged {
|
||||
// 记录违规内容
|
||||
moderation := model.Moderation{
|
||||
UserId: h.GetLoginUserId(c),
|
||||
Source: types.ModerationSourceVideo,
|
||||
Input: data.Prompt,
|
||||
Result: utils.JsonEncode(moderationResult),
|
||||
}
|
||||
err = h.DB.Create(&moderation).Error
|
||||
if err != nil {
|
||||
logger.Error("failed to save moderation: ", err)
|
||||
}
|
||||
resp.ERROR(c, "当前创作内容包含敏感词,请重新输入!")
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < h.App.SysConfig.LumaPower {
|
||||
if user.Power < h.App.SysConfig.Base.LumaPower {
|
||||
resp.ERROR(c, "您的算力不足,请充值后再试!")
|
||||
return
|
||||
}
|
||||
@@ -85,14 +127,14 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
}
|
||||
// 插入数据库
|
||||
job := model.VideoJob{
|
||||
UserId: uint(userId),
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Power: h.App.SysConfig.LumaPower,
|
||||
Power: h.App.SysConfig.Base.LumaPower,
|
||||
TaskInfo: utils.JsonEncode(task),
|
||||
}
|
||||
tx := h.DB.Create(&job)
|
||||
@@ -147,7 +189,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
||||
|
||||
// 计算当前任务所需算力
|
||||
key := fmt.Sprintf("%s_%s_%s", data.Model, data.Mode, data.Duration)
|
||||
power := h.App.SysConfig.KeLingPowers[key]
|
||||
power := h.App.SysConfig.Base.KeLingPowers[key]
|
||||
if power == 0 {
|
||||
resp.ERROR(c, "当前模型暂不支持")
|
||||
return
|
||||
@@ -181,7 +223,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
||||
Type: types.VideoKeLing,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
|
||||
Channel: data.Channel,
|
||||
}
|
||||
// 插入数据库
|
||||
|
||||
375
api/main.go
375
api/main.go
@@ -19,6 +19,7 @@ import (
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/moderation"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/payment"
|
||||
"geekai/service/sd"
|
||||
@@ -30,7 +31,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -71,15 +72,16 @@ func main() {
|
||||
if configFile == "" {
|
||||
configFile = "config.toml"
|
||||
}
|
||||
debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
|
||||
logger.Info("Loading config file: ", configFile)
|
||||
if !debug {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("Panic Error:", err)
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("Panic Error:", err)
|
||||
// 打印堆栈信息
|
||||
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||
debug.PrintStack()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
app := fx.New(
|
||||
// 初始化配置应用配置
|
||||
@@ -89,16 +91,16 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
config.Path = configFile
|
||||
if debug {
|
||||
_ = core.SaveConfig(config)
|
||||
}
|
||||
return config
|
||||
}),
|
||||
// 创建应用服务
|
||||
fx.Provide(core.NewServer),
|
||||
// 初始化
|
||||
fx.Invoke(func(s *core.AppServer, client *redis.Client) {
|
||||
s.Init(debug, client)
|
||||
s.Init(client)
|
||||
}),
|
||||
fx.Provide(func(db *gorm.DB) *types.SystemConfig {
|
||||
return core.LoadSystemConfig(db)
|
||||
}),
|
||||
|
||||
// 初始化数据库
|
||||
@@ -126,7 +128,7 @@ func main() {
|
||||
}),
|
||||
|
||||
// 创建控制器
|
||||
fx.Provide(handler.NewChatRoleHandler),
|
||||
fx.Provide(handler.NewChatAppHandler),
|
||||
fx.Provide(handler.NewUserHandler),
|
||||
fx.Provide(handler.NewChatHandler),
|
||||
fx.Provide(handler.NewNetHandler),
|
||||
@@ -143,6 +145,12 @@ func main() {
|
||||
fx.Provide(handler.NewPowerLogHandler),
|
||||
fx.Provide(handler.NewJimengHandler),
|
||||
|
||||
fx.Provide(service.NewMigrationService),
|
||||
fx.Invoke(func(migrationService *service.MigrationService) {
|
||||
migrationService.StartMigrate()
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
fx.Provide(admin.NewConfigHandler),
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
fx.Provide(admin.NewApiKeyHandler),
|
||||
@@ -153,34 +161,23 @@ func main() {
|
||||
fx.Provide(admin.NewChatModelHandler),
|
||||
fx.Provide(admin.NewProductHandler),
|
||||
fx.Provide(admin.NewOrderHandler),
|
||||
fx.Provide(admin.NewChatHandler),
|
||||
fx.Provide(admin.NewPowerLogHandler),
|
||||
fx.Provide(admin.NewAdminJimengHandler),
|
||||
|
||||
// 创建服务
|
||||
fx.Provide(sms.NewSendServiceManager),
|
||||
fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
|
||||
return service.NewCaptchaService(config.ApiConfig)
|
||||
}),
|
||||
fx.Provide(oss.NewUploaderManager),
|
||||
fx.Provide(dalle.NewService),
|
||||
fx.Invoke(func(s *dalle.Service) {
|
||||
s.Run()
|
||||
s.DownloadImages()
|
||||
s.CheckTaskStatus()
|
||||
}),
|
||||
|
||||
fx.Provide(service.NewMigrationService),
|
||||
fx.Invoke(func(s *service.MigrationService) {
|
||||
s.Migrate()
|
||||
}),
|
||||
|
||||
// 邮件服务
|
||||
fx.Provide(service.NewSmtpService),
|
||||
// License 服务
|
||||
fx.Provide(service.NewLicenseService),
|
||||
fx.Invoke(func(licenseService *service.LicenseService) {
|
||||
// licenseService.SyncLicense()
|
||||
licenseService.SyncLicense()
|
||||
}),
|
||||
|
||||
// Dalle 服务
|
||||
fx.Provide(dalle.NewService),
|
||||
fx.Invoke(func(s *dalle.Service) {
|
||||
s.Run()
|
||||
s.DownloadImages()
|
||||
s.CheckTaskStatus()
|
||||
}),
|
||||
|
||||
// MidJourney service pool
|
||||
@@ -213,302 +210,179 @@ func main() {
|
||||
}),
|
||||
|
||||
// 即梦AI 服务
|
||||
fx.Provide(jimeng.NewClient),
|
||||
fx.Provide(jimeng.NewService),
|
||||
fx.Invoke(func(service *jimeng.Service) {
|
||||
service.Start()
|
||||
}),
|
||||
fx.Provide(service.NewUserService),
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewHuPiPay),
|
||||
fx.Provide(payment.NewJPayService),
|
||||
fx.Provide(payment.NewWechatService),
|
||||
|
||||
fx.Provide(service.NewSnowflake),
|
||||
fx.Provide(service.NewXXLJobExecutor),
|
||||
fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
|
||||
if config.XXLConfig.Enabled {
|
||||
go func() {
|
||||
log.Fatal(exec.Run())
|
||||
}()
|
||||
}
|
||||
|
||||
// 创建短信服务
|
||||
fx.Provide(sms.NewAliYunSmsService),
|
||||
fx.Provide(sms.NewBaoSmsService),
|
||||
fx.Provide(sms.NewSmsManager),
|
||||
fx.Provide(func(config *types.SystemConfig) *service.CaptchaService {
|
||||
return service.NewCaptchaService(config.Captcha)
|
||||
}),
|
||||
fx.Provide(func(config *types.SystemConfig, client *redis.Client) *service.WxLoginService {
|
||||
return service.NewWxLoginService(config.WxLogin, client)
|
||||
}),
|
||||
|
||||
// 支付服务
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewEPayService),
|
||||
fx.Provide(payment.NewWxpayService),
|
||||
|
||||
// 文件上传服务
|
||||
fx.Provide(oss.NewLocalStorage),
|
||||
fx.Provide(oss.NewMiniOss),
|
||||
fx.Provide(oss.NewQiNiuOss),
|
||||
fx.Provide(oss.NewAliYunOss),
|
||||
fx.Provide(oss.NewUploaderManager),
|
||||
|
||||
// 用户服务
|
||||
fx.Provide(service.NewUserService),
|
||||
|
||||
// 文本审查服务
|
||||
fx.Provide(moderation.NewGiteeAIModeration),
|
||||
fx.Provide(moderation.NewBaiduAIModeration),
|
||||
fx.Provide(moderation.NewTencentAIModeration),
|
||||
fx.Provide(moderation.NewServiceManager),
|
||||
fx.Provide(admin.NewModerationHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ModerationHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
// 注册路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
||||
group := s.Engine.Group("/api/app/")
|
||||
group.GET("list", h.List)
|
||||
group.GET("list/user", h.ListByUser)
|
||||
group.POST("update", h.UpdateRole)
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) {
|
||||
group := s.Engine.Group("/api/user/")
|
||||
group.POST("register", h.Register)
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
group.GET("session", h.Session)
|
||||
group.GET("profile", h.Profile)
|
||||
group.POST("profile/update", h.ProfileUpdate)
|
||||
group.POST("password", h.UpdatePass)
|
||||
group.POST("bind/mobile", h.BindMobile)
|
||||
group.POST("bind/email", h.BindEmail)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
group.GET("clogin", h.CLogin)
|
||||
group.GET("clogin/callback", h.CLoginCallback)
|
||||
group.GET("signin", h.SignIn)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
|
||||
group := s.Engine.Group("/api/chat/")
|
||||
group.Any("message", h.Chat)
|
||||
group.GET("list", h.List)
|
||||
group.GET("detail", h.Detail)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("history", h.History)
|
||||
group.GET("clear", h.Clear)
|
||||
group.POST("tokens", h.Tokens)
|
||||
group.GET("stop", h.StopGenerate)
|
||||
group.POST("tts", h.TextToSpeech)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
|
||||
s.Engine.POST("/api/upload", h.Upload)
|
||||
s.Engine.POST("/api/upload/list", h.List)
|
||||
s.Engine.GET("/api/upload/remove", h.Remove)
|
||||
s.Engine.GET("/api/download", h.Download)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
|
||||
group := s.Engine.Group("/api/sms/")
|
||||
group.POST("code", h.SendCode)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
|
||||
group := s.Engine.Group("/api/captcha/")
|
||||
group.GET("get", h.Get)
|
||||
group.POST("check", h.Check)
|
||||
group.GET("slide/get", h.SlideGet)
|
||||
group.POST("slide/check", h.SlideCheck)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
|
||||
group := s.Engine.Group("/api/redeem/")
|
||||
group.POST("verify", h.Verify)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||
group := s.Engine.Group("/api/mj/")
|
||||
group.POST("image", h.Image)
|
||||
group.POST("upscale", h.Upscale)
|
||||
group.POST("variation", h.Variation)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||
group := s.Engine.Group("/api/sd")
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
||||
group := s.Engine.Group("/api/config/")
|
||||
group.GET("get", h.Get)
|
||||
group.GET("license", h.License)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
// 管理后台路由注册
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||
group := s.Engine.Group("/api/admin/config")
|
||||
group.POST("update", h.Update)
|
||||
group.GET("get", h.Get)
|
||||
group.POST("active", h.Active)
|
||||
group.GET("fixData", h.FixData)
|
||||
group.GET("license", h.GetLicense)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
group.GET("session", h.Session)
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("enable", h.Enable)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
||||
group := s.Engine.Group("/api/admin/apikey/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
|
||||
group := s.Engine.Group("/api/admin/user/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("loginLog", h.LoginLog)
|
||||
group.GET("genLoginLink", h.GenLoginLink)
|
||||
group.POST("resetPass", h.ResetPass)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppHandler) {
|
||||
group := s.Engine.Group("/api/admin/role/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("save", h.Save)
|
||||
group.POST("sort", h.Sort)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
|
||||
group := s.Engine.Group("/api/admin/redeem/")
|
||||
group.GET("list", h.List)
|
||||
group.POST("create", h.Create)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("export", h.Export)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
||||
group := s.Engine.Group("/api/admin/dashboard/")
|
||||
group.GET("stats", h.Stats)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) {
|
||||
group := s.Engine.Group("/api/model/")
|
||||
group.GET("list", h.List)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) {
|
||||
group := s.Engine.Group("/api/admin/model/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("set", h.Set)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
|
||||
group := s.Engine.Group("/api/payment/")
|
||||
group.POST("doPay", h.Pay)
|
||||
group.GET("payWays", h.GetPayWays)
|
||||
group.POST("notify/alipay", h.AlipayNotify)
|
||||
group.GET("notify/geek", h.GeekPayNotify)
|
||||
group.POST("notify/wechat", h.WechatPayNotify)
|
||||
group.POST("notify/hupi", h.HuPiPayNotify)
|
||||
h.RegisterRoutes()
|
||||
h.StartSyncOrders()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
|
||||
group := s.Engine.Group("/api/admin/product/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) {
|
||||
group := s.Engine.Group("/api/admin/order/")
|
||||
group.POST("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("clear", h.Clear)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
|
||||
group := s.Engine.Group("/api/order/")
|
||||
group.GET("list", h.List)
|
||||
group.GET("query", h.Query)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
|
||||
group := s.Engine.Group("/api/product/")
|
||||
group.GET("list", h.List)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
fx.Provide(handler.NewInviteHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
|
||||
group := s.Engine.Group("/api/invite/")
|
||||
group.GET("code", h.Code)
|
||||
group.GET("list", h.List)
|
||||
group.GET("hits", h.Hits)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
fx.Provide(admin.NewFunctionHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
|
||||
group := s.Engine.Group("/api/admin/function/")
|
||||
group.POST("save", h.Save)
|
||||
group.POST("set", h.Set)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("token", h.GenToken)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
fx.Provide(admin.NewUploadHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
|
||||
s.Engine.POST("/api/admin/upload", h.Upload)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
fx.Provide(handler.NewFunctionHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
||||
group := s.Engine.Group("/api/function/")
|
||||
group.POST("weibo", h.WeiBo)
|
||||
group.POST("zaobao", h.ZaoBao)
|
||||
group.POST("dalle3", h.Dall3)
|
||||
group.POST("websearch", h.WebSearch)
|
||||
group.GET("list", h.List)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(admin.NewChatHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
|
||||
group := s.Engine.Group("/api/admin/chat/")
|
||||
group.POST("list", h.List)
|
||||
group.POST("message", h.Messages)
|
||||
group.GET("history", h.History)
|
||||
group.GET("remove", h.RemoveChat)
|
||||
group.GET("message/remove", h.RemoveMessage)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
|
||||
group := s.Engine.Group("/api/powerLog/")
|
||||
group.POST("list", h.List)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
|
||||
group := s.Engine.Group("/api/admin/powerLog/")
|
||||
group.POST("list", h.List)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(admin.NewMenuHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
|
||||
group := s.Engine.Group("/api/admin/menu/")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
group.GET("remove", h.Remove)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewMenuHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
|
||||
group := s.Engine.Group("/api/menu/")
|
||||
group.GET("list", h.List)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewMarkMapHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
|
||||
s.Engine.POST("/api/markMap/gen", h.Generate)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewDallJobHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
|
||||
group := s.Engine.Group("/api/dall")
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.GET("imgWall", h.ImgWall)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
group.GET("models", h.GetModels)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewSunoHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
|
||||
group := s.Engine.Group("/api/suno")
|
||||
group.POST("create", h.Create)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("detail", h.Detail)
|
||||
group.GET("play", h.Play)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewVideoHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
|
||||
group := s.Engine.Group("/api/video")
|
||||
group.POST("luma/create", h.LumaCreate)
|
||||
group.POST("keling/create", h.KeLingCreate)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
|
||||
// 即梦AI 路由
|
||||
@@ -520,30 +394,19 @@ func main() {
|
||||
}),
|
||||
fx.Provide(admin.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
||||
group := s.Engine.Group("/api/admin/app/type")
|
||||
group.POST("save", h.Save)
|
||||
group.GET("list", h.List)
|
||||
group.GET("remove", h.Remove)
|
||||
group.POST("enable", h.Enable)
|
||||
group.POST("sort", h.Sort)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
|
||||
group := s.Engine.Group("/api/app/type")
|
||||
group.GET("list", h.List)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewTestHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
||||
group := s.Engine.Group("/api/test")
|
||||
group.Any("sse", h.PostTest, h.SseTest)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewPromptHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
|
||||
group := s.Engine.Group("/api/prompt")
|
||||
group.POST("/lyric", h.Lyric)
|
||||
group.POST("/image", h.Image)
|
||||
group.POST("/video", h.Video)
|
||||
group.POST("/meta", h.MetaPrompt)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||
go func() {
|
||||
@@ -568,23 +431,15 @@ func main() {
|
||||
}),
|
||||
fx.Provide(admin.NewImageHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ImageHandler) {
|
||||
group := s.Engine.Group("/api/admin/image")
|
||||
group.POST("/list/mj", h.MjList)
|
||||
group.POST("/list/sd", h.SdList)
|
||||
group.POST("/list/dall", h.DallList)
|
||||
group.GET("/remove", h.Remove)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(admin.NewMediaHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.MediaHandler) {
|
||||
group := s.Engine.Group("/api/admin/media")
|
||||
group.POST("/suno", h.SunoList)
|
||||
group.POST("/videos", h.Videos)
|
||||
group.GET("/remove", h.Remove)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(handler.NewRealtimeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
|
||||
s.Engine.Any("/api/realtime", h.Connection)
|
||||
s.Engine.POST("/api/realtime/voice", h.VoiceChat)
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
)
|
||||
// 启动应用程序
|
||||
|
||||
@@ -8,35 +8,38 @@ package service
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type CaptchaService struct {
|
||||
config types.ApiConfig
|
||||
config types.CaptchaConfig
|
||||
client *req.Client
|
||||
}
|
||||
|
||||
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
|
||||
func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
|
||||
return &CaptchaService{
|
||||
config: config,
|
||||
config: captchaConfig,
|
||||
client: req.C().SetTimeout(10 * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CaptchaService) UpdateConfig(config types.CaptchaConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *CaptchaService) GetConfig() types.CaptchaConfig {
|
||||
return s.config
|
||||
}
|
||||
|
||||
func (s *CaptchaService) Get() (interface{}, error) {
|
||||
if s.config.Token == "" {
|
||||
return nil, errors.New("无效的 API Token")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL)
|
||||
url := fmt.Sprintf("%s/api/captcha/get", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||
@@ -49,12 +52,11 @@ func (s *CaptchaService) Get() (interface{}, error) {
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func (s *CaptchaService) Check(data interface{}) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL)
|
||||
func (s *CaptchaService) Check(data any) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/check", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetBodyJsonMarshal(data).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
@@ -68,16 +70,11 @@ func (s *CaptchaService) Check(data interface{}) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *CaptchaService) SlideGet() (interface{}, error) {
|
||||
if s.config.Token == "" {
|
||||
return nil, errors.New("无效的 API Token")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
|
||||
func (s *CaptchaService) SlideGet() (any, error) {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/get", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||
@@ -90,12 +87,11 @@ func (s *CaptchaService) SlideGet() (interface{}, error) {
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func (s *CaptchaService) SlideCheck(data interface{}) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
|
||||
func (s *CaptchaService) SlideCheck(data any) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/check", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetBodyJsonMarshal(data).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
package crawler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/logger"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/launcher"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
// Service 网络爬虫服务
|
||||
type Service struct {
|
||||
browser *rod.Browser
|
||||
}
|
||||
|
||||
// NewService 创建一个新的爬虫服务
|
||||
func NewService() (*Service, error) {
|
||||
// 启动浏览器
|
||||
path, _ := launcher.LookPath()
|
||||
u := launcher.New().Bin(path).
|
||||
Headless(true). // 无头模式
|
||||
Set("disable-web-security", ""). // 禁用网络安全限制
|
||||
Set("disable-gpu", ""). // 禁用 GPU 加速
|
||||
Set("no-sandbox", ""). // 禁用沙箱模式
|
||||
Set("disable-setuid-sandbox", ""). // 禁用 setuid 沙箱
|
||||
MustLaunch()
|
||||
|
||||
browser := rod.New().ControlURL(u).MustConnect()
|
||||
|
||||
return &Service{
|
||||
browser: browser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SearchResult 搜索结果
|
||||
type SearchResult struct {
|
||||
Title string `json:"title"` // 标题
|
||||
URL string `json:"url"` // 链接
|
||||
Content string `json:"content"` // 内容摘要
|
||||
}
|
||||
|
||||
// WebSearch 网络搜索
|
||||
func (s *Service) WebSearch(keyword string, maxPages int) ([]SearchResult, error) {
|
||||
if keyword == "" {
|
||||
return nil, errors.New("搜索关键词不能为空")
|
||||
}
|
||||
|
||||
if maxPages <= 0 {
|
||||
maxPages = 1
|
||||
}
|
||||
if maxPages > 10 {
|
||||
maxPages = 10 // 最多搜索 10 页
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0)
|
||||
|
||||
// 使用百度搜索
|
||||
searchURL := fmt.Sprintf("https://www.baidu.com/s?wd=%s", url.QueryEscape(keyword))
|
||||
|
||||
// 设置页面超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 创建页面
|
||||
page := s.browser.MustPage()
|
||||
defer page.MustClose()
|
||||
|
||||
// 设置视口大小
|
||||
err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{
|
||||
Width: 1280,
|
||||
Height: 800,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("设置视口失败: %v", err)
|
||||
}
|
||||
|
||||
// 导航到搜索页面
|
||||
err = page.Context(ctx).Navigate(searchURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("导航到搜索页面失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待搜索结果加载完成
|
||||
err = page.WaitLoad()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("等待页面加载完成失败: %v", err)
|
||||
}
|
||||
|
||||
// 分析当前页面的搜索结果
|
||||
for i := 0; i < maxPages; i++ {
|
||||
if i > 0 {
|
||||
// 点击下一页按钮
|
||||
nextPage, err := page.Element("a.n")
|
||||
if err != nil || nextPage == nil {
|
||||
break // 没有下一页
|
||||
}
|
||||
|
||||
err = nextPage.Click(proto.InputMouseButtonLeft, 1)
|
||||
if err != nil {
|
||||
break // 点击下一页失败
|
||||
}
|
||||
|
||||
// 等待新页面加载
|
||||
err = page.WaitLoad()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 提取搜索结果
|
||||
resultElements, err := page.Elements(".result, .c-container")
|
||||
if err != nil || resultElements == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, result := range resultElements {
|
||||
// 获取标题
|
||||
titleElement, err := result.Element("h3, .t")
|
||||
if err != nil || titleElement == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
title, err := titleElement.Text()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 URL
|
||||
linkElement, err := titleElement.Element("a")
|
||||
if err != nil || linkElement == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
href, err := linkElement.Attribute("href")
|
||||
if err != nil || href == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取内容摘要 - 尝试多个可能的选择器
|
||||
var contentElement *rod.Element
|
||||
var content string
|
||||
|
||||
// 尝试多个可能的选择器来适应不同版本的百度搜索结果
|
||||
selectors := []string{".content-right_8Zs40", ".c-abstract", ".content_LJ0WN", ".content"}
|
||||
for _, selector := range selectors {
|
||||
contentElement, err = result.Element(selector)
|
||||
if err == nil && contentElement != nil {
|
||||
content, _ = contentElement.Text()
|
||||
if content != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有选择器都失败,尝试直接从结果块中提取文本
|
||||
if content == "" {
|
||||
// 获取结果元素的所有文本
|
||||
fullText, err := result.Text()
|
||||
if err == nil && fullText != "" {
|
||||
// 简单处理:从全文中移除标题,剩下的可能是摘要
|
||||
fullText = strings.Replace(fullText, title, "", 1)
|
||||
// 清理文本
|
||||
content = strings.TrimSpace(fullText)
|
||||
// 限制内容长度
|
||||
if len(content) > 200 {
|
||||
content = content[:200] + "..."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加到结果集
|
||||
results = append(results, SearchResult{
|
||||
Title: title,
|
||||
URL: *href,
|
||||
Content: content,
|
||||
})
|
||||
|
||||
// 限制结果数量,每页最多 10 条
|
||||
if len(results) >= 10*maxPages {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取真实 URL(百度搜索结果中的 URL 是短链接,需要跳转获取真实 URL)
|
||||
for i, result := range results {
|
||||
realURL, err := s.getRedirectURL(result.URL)
|
||||
if err == nil && realURL != "" {
|
||||
results[i].URL = realURL
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 获取真实 URL
|
||||
func (s *Service) getRedirectURL(shortURL string) (string, error) {
|
||||
// 创建页面
|
||||
page, err := s.browser.Page(proto.TargetCreateTarget{URL: ""})
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
defer func() {
|
||||
_ = page.Close()
|
||||
}()
|
||||
|
||||
// 导航到短链接
|
||||
err = page.Navigate(shortURL)
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
|
||||
// 等待重定向完成
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 获取当前 URL
|
||||
info, err := page.Info()
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
|
||||
return info.URL, nil
|
||||
}
|
||||
|
||||
// Close 关闭浏览器
|
||||
func (s *Service) Close() error {
|
||||
if s.browser != nil {
|
||||
err := s.browser.Close()
|
||||
s.browser = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchWeb 封装的搜索方法
|
||||
func SearchWeb(keyword string, maxPages int) (string, error) {
|
||||
// 添加panic恢复机制
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log := logger.GetLogger()
|
||||
log.Errorf("爬虫服务崩溃: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
service, err := NewService()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建爬虫服务失败: %v", err)
|
||||
}
|
||||
defer service.Close()
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用goroutine和通道来处理超时
|
||||
resultChan := make(chan []SearchResult, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
results, err := service.WebSearch(keyword, maxPages)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resultChan <- results
|
||||
}()
|
||||
|
||||
// 等待结果或超时
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("搜索超时: %v", ctx.Err())
|
||||
case err := <-errChan:
|
||||
return "", fmt.Errorf("搜索失败: %v", err)
|
||||
case results := <-resultChan:
|
||||
if len(results) == 0 {
|
||||
return "未找到关于 \"" + keyword + "\" 的相关搜索结果", nil
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("为您找到关于 \"%s\" 的 %d 条搜索结果:\n\n", keyword, len(results)))
|
||||
|
||||
for i, result := range results {
|
||||
// // 尝试打开链接获取实际内容
|
||||
// page := service.browser.MustPage()
|
||||
// defer page.MustClose()
|
||||
|
||||
// // 设置页面超时
|
||||
// pageCtx, pageCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
// defer pageCancel()
|
||||
|
||||
// // 导航到目标页面
|
||||
// err := page.Context(pageCtx).Navigate(result.URL)
|
||||
// if err == nil {
|
||||
// // 等待页面加载
|
||||
// _ = page.WaitLoad()
|
||||
|
||||
// // 获取页面标题
|
||||
// title, err := page.Eval("() => document.title")
|
||||
// if err == nil && title.Value.String() != "" {
|
||||
// result.Title = title.Value.String()
|
||||
// }
|
||||
|
||||
// // 获取页面主要内容
|
||||
// if content, err := page.Element("body"); err == nil {
|
||||
// if text, err := content.Text(); err == nil {
|
||||
// // 清理并截取内容
|
||||
// text = strings.TrimSpace(text)
|
||||
// if len(text) > 200 {
|
||||
// text = text[:200] + "..."
|
||||
// }
|
||||
// result.Prompt = text
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
builder.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, result.Title))
|
||||
builder.WriteString(fmt.Sprintf(" 链接: %s\n", result.URL))
|
||||
if result.Content != "" {
|
||||
builder.WriteString(fmt.Sprintf(" 摘要: %s\n", result.Content))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
@@ -94,12 +95,14 @@ func (s *Service) Run() {
|
||||
}
|
||||
|
||||
type imgReq struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Image []string `json:"image,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type imgRes struct {
|
||||
@@ -122,15 +125,6 @@ type ErrRes struct {
|
||||
|
||||
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
logger.Debugf("绘画参数:%+v", task)
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
if task.ModelId > 0 {
|
||||
@@ -160,12 +154,17 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
||||
reqBody := imgReq{
|
||||
Model: chatModel.Value,
|
||||
Prompt: prompt,
|
||||
Prompt: task.Prompt,
|
||||
N: 1,
|
||||
Size: task.Size,
|
||||
Style: task.Style,
|
||||
Quality: task.Quality,
|
||||
}
|
||||
// 图片编辑
|
||||
if len(task.Image) > 0 {
|
||||
reqBody.Prompt = fmt.Sprintf("%s, %s", strings.Join(task.Image, " "), task.Prompt)
|
||||
}
|
||||
|
||||
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
||||
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
@@ -188,7 +187,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
var imgURL string
|
||||
var data = map[string]interface{}{
|
||||
"progress": 100,
|
||||
"prompt": prompt,
|
||||
"prompt": task.Prompt,
|
||||
}
|
||||
// 如果返回的是base64,则需要上传到oss
|
||||
if res.Data[0].B64Json != "" {
|
||||
@@ -210,11 +209,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
|
||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", task.Prompt, imgURL)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
|
||||
@@ -3,8 +3,10 @@ package jimeng
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
@@ -13,14 +15,22 @@ import (
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
config types.JimengConfig
|
||||
}
|
||||
|
||||
// NewClient 创建即梦API客户端
|
||||
func NewClient(accessKey, secretKey string) *Client {
|
||||
func NewClient(sysConfig *types.SystemConfig) *Client {
|
||||
|
||||
client := &Client{}
|
||||
client.UpdateConfig(sysConfig.Jimeng)
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *Client) UpdateConfig(config types.JimengConfig) error {
|
||||
// 使用官方SDK的visual实例
|
||||
visualInstance := visual.NewInstance()
|
||||
visualInstance.Client.SetAccessKey(accessKey)
|
||||
visualInstance.Client.SetSecretKey(secretKey)
|
||||
visualInstance.Client.SetAccessKey(config.AccessKey)
|
||||
visualInstance.Client.SetSecretKey(config.SecretKey)
|
||||
|
||||
// 添加即梦AI专有的API配置
|
||||
jimengApis := map[string]*base.ApiInfo{
|
||||
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
|
||||
visualInstance.Client.ApiInfoList[name] = info
|
||||
}
|
||||
|
||||
return &Client{
|
||||
visual: visualInstance,
|
||||
c.config = config
|
||||
c.visual = visualInstance
|
||||
|
||||
return c.testConnection()
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (c *Client) testConnection() error {
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := c.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -16,8 +15,6 @@ import (
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
@@ -36,17 +33,8 @@ type Service struct {
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
db.Where("name = ?", "Jimeng").First(&config)
|
||||
var jimengConfig types.JimengConfig
|
||||
if config.Id > 0 {
|
||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
}
|
||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
@@ -378,7 +366,7 @@ func (s *Service) pollTaskStatus() {
|
||||
|
||||
for _, job := range jobs {
|
||||
// 任务超时处理
|
||||
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
|
||||
if job.UpdatedAt.Before(time.Now().Add(-10 * time.Minute)) {
|
||||
s.handleTaskError(job.Id, "task timeout")
|
||||
continue
|
||||
}
|
||||
@@ -391,7 +379,7 @@ func (s *Service) pollTaskStatus() {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("query jimeng task status failed: %v", err)
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -446,9 +434,7 @@ func (s *Service) pollTaskStatus() {
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
// 任务过期
|
||||
s.handleTaskError(job.Id, "task expired")
|
||||
|
||||
continue
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
@@ -524,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := testClient.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateClientConfig 更新客户端配置
|
||||
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
||||
// 创建新的客户端
|
||||
newClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 测试新客户端是否可用
|
||||
err := s.testConnection(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新客户端
|
||||
s.client = newClient
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultPower = types.JimengPower{
|
||||
TextToImage: 20,
|
||||
ImageToImage: 20,
|
||||
ImageEdit: 20,
|
||||
ImageEffects: 20,
|
||||
TextToVideo: 300,
|
||||
ImageToVideo: 300,
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (s *Service) GetConfig() *types.JimengConfig {
|
||||
var config model.Config
|
||||
err := s.db.Where("name", "jimeng").First(&config).Error
|
||||
if err != nil {
|
||||
// 如果配置不存在,返回默认配置
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
var jimengConfig types.JimengConfig
|
||||
err = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
return &jimengConfig
|
||||
}
|
||||
|
||||
@@ -8,30 +8,37 @@ package service
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LicenseService struct {
|
||||
config types.ApiConfig
|
||||
levelDB *store.LevelDB
|
||||
license *types.License
|
||||
urlWhiteList []string
|
||||
machineId string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
||||
var license types.License
|
||||
func NewLicenseService(sysConfig *types.SystemConfig, db *gorm.DB) *LicenseService {
|
||||
var machineId string
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
machineId = info.HostID
|
||||
}
|
||||
logger.Infof("License: %+v", sysConfig.License)
|
||||
return &LicenseService{
|
||||
config: server.Config.ApiConfig,
|
||||
levelDB: levelDB,
|
||||
license: &license,
|
||||
machineId: "",
|
||||
license: &sysConfig.License,
|
||||
machineId: machineId,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,15 +53,15 @@ type License struct {
|
||||
}
|
||||
|
||||
// ActiveLicense 激活 License
|
||||
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
||||
func (s *LicenseService) ActiveLicense(license string) error {
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data License `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
|
||||
response, err := req.C().R().
|
||||
SetBody(map[string]string{"license": license, "machine_id": machineId}).
|
||||
SetBody(map[string]string{"license": license, "machine_id": s.machineId}).
|
||||
SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送激活请求失败: %v", err)
|
||||
@@ -68,17 +75,24 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
||||
return fmt.Errorf("激活失败:%v", res.Message)
|
||||
}
|
||||
|
||||
if res.Data.ExpiredAt > 0 && res.Data.ExpiredAt < time.Now().Unix() {
|
||||
return fmt.Errorf("License 已过期")
|
||||
}
|
||||
|
||||
s.license = &types.License{
|
||||
Key: license,
|
||||
MachineId: machineId,
|
||||
MachineId: s.machineId,
|
||||
Configs: res.Data.Configs,
|
||||
ExpiredAt: res.Data.ExpiredAt,
|
||||
IsActive: true,
|
||||
}
|
||||
err = s.levelDB.Put(types.LicenseKey, s.license)
|
||||
|
||||
// 保存 License 到数据库
|
||||
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存许可证书失败:%v", err)
|
||||
return fmt.Errorf("保存 License 到数据库失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -96,6 +110,11 @@ func (s *LicenseService) SyncLicense() {
|
||||
s.license.IsActive = false
|
||||
} else {
|
||||
s.license = license
|
||||
// 保存 License 到数据库
|
||||
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||
if err != nil {
|
||||
logger.Errorf("保存 License 到数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
urls, err := s.fetchUrlWhiteList()
|
||||
@@ -109,33 +128,30 @@ func (s *LicenseService) SyncLicense() {
|
||||
}
|
||||
|
||||
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
||||
//var res struct {
|
||||
// Code types.BizCode `json:"code"`
|
||||
// Message string `json:"message"`
|
||||
// Data License `json:"data"`
|
||||
//}
|
||||
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
||||
//response, err := req.C().R().
|
||||
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||
// SetSuccessResult(&res).Post(apiURL)
|
||||
//if err != nil {
|
||||
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
||||
//}
|
||||
//if response.IsErrorState() {
|
||||
// return nil, fmt.Errorf("激活失败:%v", response.Status)
|
||||
//}
|
||||
//if res.Code != types.Success {
|
||||
// return nil, fmt.Errorf("激活失败:%v", res.Message)
|
||||
//}
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data License `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
|
||||
response, err := req.C().R().
|
||||
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||
SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("License 同步失败: %v", err)
|
||||
}
|
||||
if response.IsErrorState() {
|
||||
return nil, fmt.Errorf("License 同步失败:%v", response.Status)
|
||||
}
|
||||
if res.Code != types.Success {
|
||||
return nil, fmt.Errorf("License 同步失败:%v", res.Message)
|
||||
}
|
||||
|
||||
return &types.License{
|
||||
Key: "abc",
|
||||
MachineId: "abc",
|
||||
Configs: types.LicenseConfig{
|
||||
UserNum: 10000,
|
||||
DeCopy: false,
|
||||
},
|
||||
ExpiredAt: 0,
|
||||
Key: res.Data.License,
|
||||
MachineId: res.Data.MachineId,
|
||||
Configs: res.Data.Configs,
|
||||
ExpiredAt: res.Data.ExpiredAt,
|
||||
IsActive: true,
|
||||
}, nil
|
||||
}
|
||||
@@ -146,7 +162,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||
Message string `json:"message"`
|
||||
Data []string `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
|
||||
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %v", err)
|
||||
@@ -163,35 +179,46 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||
|
||||
// GetLicense 获取许可信息
|
||||
func (s *LicenseService) GetLicense() *types.License {
|
||||
if s.license == nil {
|
||||
var config model.Config
|
||||
s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).First(&config)
|
||||
if config.Value != "" {
|
||||
utils.JsonDecode(config.Value, &s.license)
|
||||
}
|
||||
}
|
||||
return s.license
|
||||
}
|
||||
|
||||
func (s *LicenseService) SetLicense(licenseKey string) {
|
||||
s.license.Key = licenseKey
|
||||
|
||||
}
|
||||
|
||||
// IsValidApiURL 判断是否合法的中转 URL
|
||||
func (s *LicenseService) IsValidApiURL(uri string) error {
|
||||
// 获得许可授权的直接放行
|
||||
return nil
|
||||
//if s.license.IsActive {
|
||||
// if s.license.MachineId != s.machineId {
|
||||
// return errors.New("系统使用了盗版的许可证书")
|
||||
// }
|
||||
//
|
||||
// if time.Now().Unix() > s.license.ExpiredAt {
|
||||
// return errors.New("系统许可证书已经过期")
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//if len(s.urlWhiteList) == 0 {
|
||||
// urls, err := s.fetchUrlWhiteList()
|
||||
// if err == nil {
|
||||
// s.urlWhiteList = urls
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//for _, v := range s.urlWhiteList {
|
||||
// if strings.HasPrefix(uri, v) {
|
||||
// return nil
|
||||
// }
|
||||
//}
|
||||
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||
if s.license.IsActive {
|
||||
if s.license.MachineId != s.machineId {
|
||||
return errors.New("系统使用了盗版的许可证书")
|
||||
}
|
||||
|
||||
if time.Now().Unix() > s.license.ExpiredAt {
|
||||
return errors.New("系统许可证书已经过期")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(s.urlWhiteList) == 0 {
|
||||
urls, err := s.fetchUrlWhiteList()
|
||||
if err == nil {
|
||||
s.urlWhiteList = urls
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range s.urlWhiteList {
|
||||
if strings.HasPrefix(uri, v) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||
}
|
||||
|
||||
@@ -1,52 +1,342 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// Use of this source code is governed by a Apache-2.0 license
|
||||
// that can be found in the LICENSE file.
|
||||
// @Author yangjian102621@163.com
|
||||
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"strings"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
// 迁移状态Redis key
|
||||
MigrationStatusKey = "config_migration:status"
|
||||
// 迁移完成标志
|
||||
MigrationCompleted = "completed"
|
||||
)
|
||||
|
||||
// MigrationService 配置迁移服务
|
||||
type MigrationService struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
redisClient *redis.Client
|
||||
appConfig *types.AppConfig
|
||||
levelDB *store.LevelDB
|
||||
licenseService *LicenseService
|
||||
}
|
||||
|
||||
func NewMigrationService(db *gorm.DB) *MigrationService {
|
||||
return &MigrationService{db: db}
|
||||
func NewMigrationService(db *gorm.DB, redisClient *redis.Client, appConfig *types.AppConfig, levelDB *store.LevelDB, licenseService *LicenseService) *MigrationService {
|
||||
return &MigrationService{
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
appConfig: appConfig,
|
||||
levelDB: levelDB,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MigrationService) Migrate() error {
|
||||
err := s.db.AutoMigrate(
|
||||
&model.AdminUser{},
|
||||
&model.ApiKey{},
|
||||
&model.AppType{},
|
||||
&model.ChatItem{},
|
||||
&model.ChatMessage{},
|
||||
&model.ChatModel{},
|
||||
&model.ChatRole{},
|
||||
&model.Config{},
|
||||
&model.DallJob{},
|
||||
&model.File{},
|
||||
&model.Function{},
|
||||
&model.InviteCode{},
|
||||
&model.InviteLog{},
|
||||
&model.Menu{},
|
||||
&model.MidJourneyJob{},
|
||||
&model.Order{},
|
||||
&model.PowerLog{},
|
||||
&model.Product{},
|
||||
&model.Redeem{},
|
||||
&model.SdJob{},
|
||||
&model.SunoJob{},
|
||||
&model.User{},
|
||||
&model.UserLoginLog{},
|
||||
&model.VideoJob{},
|
||||
)
|
||||
return err
|
||||
func (s *MigrationService) StartMigrate() {
|
||||
go func() {
|
||||
s.MigrateConfig(s.appConfig)
|
||||
s.TableMigration()
|
||||
s.MigrateLicense()
|
||||
}()
|
||||
}
|
||||
|
||||
// 迁移 License
|
||||
func (s *MigrationService) MigrateLicense() {
|
||||
key := "migrate:license"
|
||||
if s.redisClient.Get(context.Background(), key).Val() == "1" {
|
||||
logger.Info("License 已迁移,跳过迁移")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("开始迁移 License...")
|
||||
var license types.License
|
||||
err := s.levelDB.Get(types.LicenseKey, &license)
|
||||
if err != nil {
|
||||
license = types.License{
|
||||
Key: "",
|
||||
MachineId: "",
|
||||
Configs: types.LicenseConfig{UserNum: 0, DeCopy: false},
|
||||
ExpiredAt: 0,
|
||||
IsActive: false,
|
||||
}
|
||||
}
|
||||
logger.Infof("迁移 License: %+v", license)
|
||||
if err := s.saveConfig(types.ConfigKeyLicense, license); err != nil {
|
||||
logger.Errorf("迁移 License 失败: %v", err)
|
||||
return
|
||||
}
|
||||
s.licenseService.SetLicense(license.Key)
|
||||
logger.Info("迁移 License 完成")
|
||||
s.redisClient.Set(context.Background(), key, "1", 0)
|
||||
}
|
||||
|
||||
// 迁移配置内容
|
||||
func (s *MigrationService) MigrateConfigContent() error {
|
||||
// 用户协议
|
||||
if err := s.saveConfig(types.ConfigKeyPrivacy, map[string]string{
|
||||
"content": "用户协议内容",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
// 隐私政策
|
||||
if err := s.saveConfig(types.ConfigKeyAgreement, map[string]string{
|
||||
"content": "隐私政策内容",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
// 思维导图
|
||||
if err := s.saveConfig(types.ConfigKeyMarkMap, map[string]string{
|
||||
"content": `# GeekAI 演示站
|
||||
|
||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||
- 基于 Websocket 实现,完美的打字机体验。
|
||||
- 内置了各种预训练好的角色应用,轻松满足你的各种聊天和应用需求。
|
||||
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
|
||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件。`,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 微信登录配置
|
||||
if err := s.saveConfig(types.ConfigKeyWxLogin, map[string]string{
|
||||
"api_key": "",
|
||||
"notify_url": "",
|
||||
"enabled": "false",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证码配置
|
||||
if err := s.saveConfig(types.ConfigKeyCaptcha, map[string]string{
|
||||
"api_key": "",
|
||||
"type": "dot",
|
||||
"enabled": "false",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if err := s.saveConfig(types.ConfigKeyModeration, map[string]any{
|
||||
"enable": "false",
|
||||
"active": "gitee",
|
||||
"enable_guide": "false",
|
||||
"guide_prompt": "",
|
||||
"gitee": map[string]string{
|
||||
"api_key": "",
|
||||
"model": "Security-semantic-filtering",
|
||||
},
|
||||
"baidu": map[string]string{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
"tencent": map[string]string{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 数据表迁移
|
||||
func (s *MigrationService) TableMigration() {
|
||||
// 新数据表
|
||||
s.db.AutoMigrate(&model.Moderation{})
|
||||
|
||||
// 订单字段整理
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
|
||||
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
|
||||
}
|
||||
if !s.db.Migrator().HasColumn(&model.Order{}, "checked") {
|
||||
s.db.Migrator().AddColumn(&model.Order{}, "checked")
|
||||
}
|
||||
|
||||
// 重命名 config 表字段
|
||||
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
||||
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
|
||||
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
||||
}
|
||||
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
||||
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
||||
}
|
||||
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
|
||||
s.db.Migrator().DropIndex(&model.Config{}, "marker")
|
||||
}
|
||||
|
||||
// 手动删除字段
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
|
||||
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
|
||||
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
|
||||
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "discount") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "discount")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "days") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "days")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "app_url") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "app_url")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "url") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "url")
|
||||
}
|
||||
}
|
||||
|
||||
// 迁移配置数据
|
||||
func (s *MigrationService) MigrateConfig(config *types.AppConfig) error {
|
||||
|
||||
logger.Info("开始迁移配置到数据库...")
|
||||
|
||||
// 迁移支付配置
|
||||
if err := s.migratePaymentConfig(config); err != nil {
|
||||
logger.Errorf("迁移支付配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移存储配置
|
||||
if err := s.migrateStorageConfig(config); err != nil {
|
||||
logger.Errorf("迁移存储配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移通信配置
|
||||
if err := s.migrateCommunicationConfig(config); err != nil {
|
||||
logger.Errorf("迁移通信配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移配置内容
|
||||
if err := s.MigrateConfigContent(); err != nil {
|
||||
logger.Errorf("迁移配置内容失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("配置迁移完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移支付配置
|
||||
func (s *MigrationService) migratePaymentConfig(config *types.AppConfig) error {
|
||||
|
||||
paymentConfig := types.PaymentConfig{
|
||||
Alipay: config.AlipayConfig,
|
||||
Epay: config.GeekPayConfig,
|
||||
WxPay: config.WechatPayConfig,
|
||||
}
|
||||
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移存储配置
|
||||
func (s *MigrationService) migrateStorageConfig(config *types.AppConfig) error {
|
||||
|
||||
ossConfig := types.OSSConfig{
|
||||
Active: config.OSS.Active,
|
||||
Local: config.OSS.Local,
|
||||
Minio: config.OSS.Minio,
|
||||
QiNiu: config.OSS.QiNiu,
|
||||
AliYun: config.OSS.AliYun,
|
||||
}
|
||||
return s.saveConfig(types.ConfigKeyOss, ossConfig)
|
||||
}
|
||||
|
||||
// 迁移通信配置
|
||||
func (s *MigrationService) migrateCommunicationConfig(config *types.AppConfig) error {
|
||||
// SMTP配置
|
||||
smtpConfig := map[string]any{
|
||||
"use_tls": config.SmtpConfig.UseTls,
|
||||
"host": config.SmtpConfig.Host,
|
||||
"port": config.SmtpConfig.Port,
|
||||
"app_name": config.SmtpConfig.AppName,
|
||||
"from": config.SmtpConfig.From,
|
||||
"password": config.SmtpConfig.Password,
|
||||
}
|
||||
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 短信配置
|
||||
smsConfig := map[string]any{
|
||||
"active": strings.ToLower(config.SMS.Active),
|
||||
"aliyun": map[string]any{
|
||||
"access_key": config.SMS.Ali.AccessKey,
|
||||
"access_secret": config.SMS.Ali.AccessSecret,
|
||||
"sign": config.SMS.Ali.Sign,
|
||||
"code_temp_id": config.SMS.Ali.CodeTempId,
|
||||
},
|
||||
"bao": map[string]any{
|
||||
"username": config.SMS.Bao.Username,
|
||||
"password": config.SMS.Bao.Password,
|
||||
"sign": config.SMS.Bao.Sign,
|
||||
"code_template": config.SMS.Bao.CodeTemplate,
|
||||
},
|
||||
}
|
||||
return s.saveConfig(types.ConfigKeySms, smsConfig)
|
||||
}
|
||||
|
||||
// 保存配置到数据库
|
||||
func (s *MigrationService) saveConfig(key string, config any) error {
|
||||
// 检查是否已存在
|
||||
var existingConfig model.Config
|
||||
if err := s.db.Where("name", key).First(&existingConfig).Error; err == nil {
|
||||
// 配置已存在,跳过
|
||||
logger.Infof("配置 %s 已存在,跳过迁移", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 序列化配置
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
newConfig := model.Config{
|
||||
Name: key,
|
||||
Value: string(configJSON),
|
||||
}
|
||||
if err := s.db.Create(&newConfig).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("成功迁移配置 %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -67,25 +67,6 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// use fast mode as default
|
||||
if task.Mode == "" {
|
||||
task.Mode = "fast"
|
||||
|
||||
33
api/service/moderation/baidu_moderation.go
Normal file
33
api/service/moderation/baidu_moderation.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
type BaiduAIModeration struct {
|
||||
config types.ModerationBaiduConfig
|
||||
}
|
||||
|
||||
func NewBaiduAIModeration(sysConfig *types.SystemConfig) *BaiduAIModeration {
|
||||
return &BaiduAIModeration{
|
||||
config: sysConfig.Moderation.Baidu,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaiduAIModeration) UpdateConfig(config types.ModerationBaiduConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *BaiduAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
return types.ModerationResult{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
var _ Service = (*BaiduAIModeration)(nil)
|
||||
58
api/service/moderation/gitee_moderation.go
Normal file
58
api/service/moderation/gitee_moderation.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type GiteeAIModeration struct {
|
||||
config types.ModerationGiteeConfig
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewGiteeAIModeration(sysConfig *types.SystemConfig) *GiteeAIModeration {
|
||||
return &GiteeAIModeration{
|
||||
config: sysConfig.Moderation.Gitee,
|
||||
apiURL: "https://ai.gitee.com/v1/moderations",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GiteeAIModeration) UpdateConfig(config types.ModerationGiteeConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
type GiteeAIModerationResult struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Results []types.ModerationResult `json:"results"`
|
||||
}
|
||||
|
||||
func (s *GiteeAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
|
||||
body := map[string]any{
|
||||
"input": text,
|
||||
"model": s.config.Model,
|
||||
}
|
||||
var res GiteeAIModerationResult
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+s.config.ApiKey).SetBody(body).SetSuccessResult(&res).Post(s.apiURL)
|
||||
if err != nil {
|
||||
return types.ModerationResult{}, err
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return types.ModerationResult{}, errors.New(r.String())
|
||||
}
|
||||
|
||||
return res.Results[0], nil
|
||||
}
|
||||
|
||||
var _ Service = (*GiteeAIModeration)(nil)
|
||||
58
api/service/moderation/moderation_manager.go
Normal file
58
api/service/moderation/moderation_manager.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package moderation
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
)
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Service interface {
|
||||
Moderate(text string) (types.ModerationResult, error)
|
||||
}
|
||||
|
||||
type ServiceManager struct {
|
||||
gitee *GiteeAIModeration
|
||||
baidu *BaiduAIModeration
|
||||
tencent *TencentAIModeration
|
||||
active string
|
||||
}
|
||||
|
||||
func NewServiceManager(gitee *GiteeAIModeration, baidu *BaiduAIModeration, tencent *TencentAIModeration) *ServiceManager {
|
||||
return &ServiceManager{
|
||||
gitee: gitee,
|
||||
baidu: baidu,
|
||||
tencent: tencent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceManager) GetService() Service {
|
||||
switch s.active {
|
||||
case types.ModerationBaidu:
|
||||
return s.baidu
|
||||
case types.ModerationTencent:
|
||||
return s.tencent
|
||||
default:
|
||||
return s.gitee
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceManager) UpdateConfig(config types.ModerationConfig) {
|
||||
switch config.Active {
|
||||
case types.ModerationGitee:
|
||||
s.gitee.UpdateConfig(config.Gitee)
|
||||
case types.ModerationBaidu:
|
||||
s.baidu.UpdateConfig(config.Baidu)
|
||||
case types.ModerationTencent:
|
||||
s.tencent.UpdateConfig(config.Tencent)
|
||||
}
|
||||
s.active = config.Active
|
||||
}
|
||||
33
api/service/moderation/tencent_moderation.go
Normal file
33
api/service/moderation/tencent_moderation.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
type TencentAIModeration struct {
|
||||
config types.ModerationTencentConfig
|
||||
}
|
||||
|
||||
func NewTencentAIModeration(sysConfig *types.SystemConfig) *TencentAIModeration {
|
||||
return &TencentAIModeration{
|
||||
config: sysConfig.Moderation.Tencent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TencentAIModeration) UpdateConfig(config types.ModerationTencentConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *TencentAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
return types.ModerationResult{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
var _ Service = (*TencentAIModeration)(nil)
|
||||
@@ -23,35 +23,35 @@ import (
|
||||
)
|
||||
|
||||
type AliYunOss struct {
|
||||
config *types.AliYunOssConfig
|
||||
config types.AliYunOssConfig
|
||||
bucket *oss.Bucket
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
|
||||
config := &appConfig.OSS.AliYun
|
||||
// 创建 OSS 客户端
|
||||
func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
|
||||
s := &AliYunOss{
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
err := s.UpdateConfig(sysConfig.OSS.AliYun)
|
||||
if err != nil {
|
||||
logger.Warnf("阿里云OSS初始化失败: %v", err)
|
||||
}
|
||||
return s, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *AliYunOss) UpdateConfig(config types.AliYunOssConfig) error {
|
||||
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取存储空间
|
||||
bucket, err := client.Bucket(config.Bucket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
|
||||
return &AliYunOss{
|
||||
config: config,
|
||||
bucket: bucket,
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}, nil
|
||||
|
||||
s.bucket = bucket
|
||||
s.config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
@@ -68,7 +68,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
defer src.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件
|
||||
err = s.bucket.PutObject(objectKey, src)
|
||||
if err != nil {
|
||||
@@ -102,7 +102,7 @@ func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
@@ -116,7 +116,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||
if err != nil {
|
||||
@@ -128,8 +128,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s AliYunOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
objectKey = filepath.Base(fileURL)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
@@ -21,17 +21,21 @@ import (
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
config *types.LocalStorageConfig
|
||||
config types.LocalStorageConfig
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewLocalStorage(config *types.AppConfig) LocalStorage {
|
||||
return LocalStorage{
|
||||
config: &config.OSS.Local,
|
||||
proxyURL: config.ProxyURL,
|
||||
func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
|
||||
return &LocalStorage{
|
||||
config: sysConfig.OSS.Local,
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LocalStorage) UpdateConfig(config types.LocalStorageConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
file, err := ctx.FormFile(name)
|
||||
if err != nil {
|
||||
|
||||
@@ -24,24 +24,32 @@ import (
|
||||
)
|
||||
|
||||
type MiniOss struct {
|
||||
config *types.MiniOssConfig
|
||||
config types.MiniOssConfig
|
||||
client *minio.Client
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
config := &appConfig.OSS.Minio
|
||||
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
|
||||
|
||||
s := &MiniOss{proxyURL: appConfig.ProxyURL}
|
||||
err := s.UpdateConfig(sysConfig.OSS.Minio)
|
||||
if err != nil {
|
||||
logger.Warnf("MinioOSS初始化失败: %v", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *MiniOss) UpdateConfig(config types.MiniOssConfig) error {
|
||||
minioClient, err := minio.New(config.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
|
||||
Secure: config.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return MiniOss{}, err
|
||||
return err
|
||||
}
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
s.config = config
|
||||
s.client = minioClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
@@ -62,7 +70,7 @@ func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string,
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -89,7 +97,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
defer fileReader.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||
ContentType: file.Header.Get("Body-Type"),
|
||||
})
|
||||
@@ -111,7 +119,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -128,8 +136,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s MiniOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
objectKey = filepath.Base(fileURL)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
@@ -24,18 +24,24 @@ import (
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
)
|
||||
|
||||
type QinNiuOss struct {
|
||||
config *types.QiNiuOssConfig
|
||||
type QiNiuOss struct {
|
||||
config types.QiNiuOssConfig
|
||||
mac *qbox.Mac
|
||||
putPolicy storage.PutPolicy
|
||||
uploader *storage.FormUploader
|
||||
manager *storage.BucketManager
|
||||
bucket *storage.BucketManager
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
|
||||
config := &appConfig.OSS.QiNiu
|
||||
// build storage uploader
|
||||
func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QiNiuOss {
|
||||
s := &QiNiuOss{
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
s.UpdateConfig(sysConfig.OSS.QiNiu)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *QiNiuOss) UpdateConfig(config types.QiNiuOssConfig) {
|
||||
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
|
||||
if !ok {
|
||||
zone = storage.ZoneHuanan
|
||||
@@ -47,20 +53,13 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: config.Bucket,
|
||||
}
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
return QinNiuOss{
|
||||
config: config,
|
||||
mac: mac,
|
||||
putPolicy: putPolicy,
|
||||
uploader: formUploader,
|
||||
manager: storage.NewBucketManager(mac, &storeConfig),
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
s.config = config
|
||||
s.mac = mac
|
||||
s.putPolicy = putPolicy
|
||||
s.uploader = formUploader
|
||||
s.bucket = storage.NewBucketManager(mac, &storeConfig)
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
func (s QiNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
// 解析表单
|
||||
file, err := ctx.FormFile(name)
|
||||
if err != nil {
|
||||
@@ -74,7 +73,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
defer src.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
@@ -93,7 +92,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
func (s QiNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -111,7 +110,7 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
@@ -122,12 +121,12 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s QiNiuOss) PutBase64(base64Img string) (string, error) {
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
@@ -138,16 +137,15 @@ func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||
}
|
||||
|
||||
func (s QinNiuOss) Delete(fileURL string) error {
|
||||
func (s QiNiuOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
objectKey = filepath.Base(fileURL)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
return s.manager.Delete(s.config.Bucket, objectKey)
|
||||
return s.bucket.Delete(s.config.Bucket, objectKey)
|
||||
}
|
||||
|
||||
var _ Uploader = QinNiuOss{}
|
||||
var _ Uploader = QiNiuOss{}
|
||||
|
||||
@@ -9,10 +9,10 @@ package oss
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
const Local = "LOCAL"
|
||||
const Minio = "MINIO"
|
||||
const QiNiu = "QINIU"
|
||||
const AliYun = "ALIYUN"
|
||||
const Local = "local"
|
||||
const Minio = "minio"
|
||||
const QiNiu = "qiniu"
|
||||
const AliYun = "aliyun"
|
||||
|
||||
type File struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -9,45 +9,58 @@ package oss
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
"strings"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type UploaderManager struct {
|
||||
handler Uploader
|
||||
local *LocalStorage
|
||||
aliyun *AliYunOss
|
||||
mini *MiniOss
|
||||
qiniu *QiNiuOss
|
||||
active string
|
||||
}
|
||||
|
||||
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
|
||||
active := Local
|
||||
if config.OSS.Active != "" {
|
||||
active = strings.ToUpper(config.OSS.Active)
|
||||
}
|
||||
var handler Uploader
|
||||
switch active {
|
||||
case Local:
|
||||
handler = NewLocalStorage(config)
|
||||
break
|
||||
case Minio:
|
||||
client, err := NewMiniOss(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler = client
|
||||
break
|
||||
case QiNiu:
|
||||
handler = NewQiNiuOss(config)
|
||||
break
|
||||
case AliYun:
|
||||
client, err := NewAliYunOss(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler = client
|
||||
break
|
||||
func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QiNiuOss) (*UploaderManager, error) {
|
||||
if sysConfig.OSS.Active == "" {
|
||||
sysConfig.OSS.Active = Local
|
||||
}
|
||||
|
||||
return &UploaderManager{handler: handler}, nil
|
||||
return &UploaderManager{
|
||||
active: sysConfig.OSS.Active,
|
||||
local: local,
|
||||
aliyun: aliyun,
|
||||
mini: mini,
|
||||
qiniu: qiniu,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *UploaderManager) GetUploadHandler() Uploader {
|
||||
return m.handler
|
||||
switch m.active {
|
||||
case Local:
|
||||
return m.local
|
||||
case AliYun:
|
||||
return m.aliyun
|
||||
case Minio:
|
||||
return m.mini
|
||||
case QiNiu:
|
||||
return m.qiniu
|
||||
}
|
||||
return m.local
|
||||
}
|
||||
|
||||
func (m *UploaderManager) UpdateConfig(config types.OSSConfig) {
|
||||
switch config.Active {
|
||||
case Local:
|
||||
m.local.UpdateConfig(config.Local)
|
||||
case AliYun:
|
||||
m.aliyun.UpdateConfig(config.AliYun)
|
||||
case Minio:
|
||||
m.mini.UpdateConfig(config.Minio)
|
||||
case QiNiu:
|
||||
m.qiniu.UpdateConfig(config.QiNiu)
|
||||
}
|
||||
m.active = config.Active
|
||||
}
|
||||
|
||||
@@ -12,129 +12,98 @@ import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/alipay"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/alipay"
|
||||
)
|
||||
|
||||
type AlipayService struct {
|
||||
config *types.AlipayConfig
|
||||
client *alipay.Client
|
||||
config *types.AlipayConfig
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
||||
config := appConfig.AlipayConfig
|
||||
func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) {
|
||||
config := sysConfig.Payment.Alipay
|
||||
if !config.Enabled {
|
||||
logger.Info("Disabled Alipay service")
|
||||
return nil, nil
|
||||
logger.Debug("Disabled Alipay service")
|
||||
}
|
||||
priKey, err := readKey(config.PrivateKey)
|
||||
|
||||
service := &AlipayService{config: &config}
|
||||
if config.Enabled {
|
||||
err := service.UpdateConfig(&config)
|
||||
if err != nil {
|
||||
logger.Errorf("支付宝服务初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error {
|
||||
client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
||||
return fmt.Errorf("error with initialize alipay service: %v", err)
|
||||
}
|
||||
|
||||
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
|
||||
s.client = client
|
||||
s.config = config
|
||||
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||
logger.Info("Alipay Debug mode is enabled")
|
||||
client.DebugSwitch = gopay.DebugOn
|
||||
}
|
||||
|
||||
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
|
||||
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
|
||||
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
||||
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
|
||||
|
||||
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error with load payment public key: %v", err)
|
||||
}
|
||||
|
||||
return &AlipayService{config: &config, client: client}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
type AlipayParams struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
Subject string `json:"subject"`
|
||||
TotalFee string `json:"total_fee"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", params.Subject)
|
||||
bm.Set("out_trade_no", params.OutTradeNo)
|
||||
bm.Set("quit_url", params.ReturnURL)
|
||||
bm.Set("total_amount", params.TotalFee)
|
||||
bm.Set("product_code", "QUICK_WAP_WAY")
|
||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
|
||||
func (s *AlipayService) Pay(params PayRequest) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", params.Subject)
|
||||
bm.Set("out_trade_no", params.OutTradeNo)
|
||||
bm.Set("total_amount", params.TotalFee)
|
||||
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
|
||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
|
||||
return s.client.TradeWapPay(context.Background(), bm)
|
||||
}
|
||||
|
||||
func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
rsp, err := s.client.TradeQuery(context.Background(), bm)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with trade query: %v", err)
|
||||
}
|
||||
|
||||
switch rsp.Response.TradeStatus {
|
||||
case "TRADE_SUCCESS":
|
||||
logger.Debugf("支付宝查询订单成功:%+v", rsp.Response)
|
||||
return OrderInfo{
|
||||
OutTradeNo: rsp.Response.OutTradeNo,
|
||||
TradeId: rsp.Response.TradeNo,
|
||||
Amount: rsp.Response.TotalAmount,
|
||||
Status: Success,
|
||||
PayTime: rsp.Response.SendPayDate,
|
||||
}, nil
|
||||
case "TRADE_CLOSED":
|
||||
return OrderInfo{Status: Closed}, nil
|
||||
default:
|
||||
return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
|
||||
func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) {
|
||||
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "error with parse notify request: " + err.Error(),
|
||||
}
|
||||
return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err)
|
||||
}
|
||||
|
||||
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "error with verify sign: " + err.Error(),
|
||||
}
|
||||
return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err)
|
||||
}
|
||||
|
||||
return s.TradeQuery(request.Form.Get("out_trade_no"))
|
||||
return s.Query(request.Form.Get("out_trade_no"))
|
||||
}
|
||||
|
||||
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
|
||||
//查询订单
|
||||
rsp, err := s.client.TradeQuery(context.Background(), bm)
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
|
||||
return NotifyVo{
|
||||
Status: Success,
|
||||
OutTradeNo: rsp.Response.OutTradeNo,
|
||||
TradeId: rsp.Response.TradeNo,
|
||||
Amount: rsp.Response.TotalAmount,
|
||||
Subject: rsp.Response.Subject,
|
||||
Message: "OK",
|
||||
}
|
||||
} else {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "异步查询验证订单信息发生错误" + outTradeNo,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readKey(filename string) (string, error) {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
var _ PayService = (*AlipayService)(nil)
|
||||
|
||||
@@ -22,41 +22,30 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// GeekPayService Geek 支付服务
|
||||
type GeekPayService struct {
|
||||
config *types.GeekPayConfig
|
||||
// EPayService 易支付服务
|
||||
type EPayService struct {
|
||||
config *types.EpayConfig
|
||||
}
|
||||
|
||||
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
|
||||
return &GeekPayService{
|
||||
config: &appConfig.GeekPayConfig,
|
||||
func NewEPayService(sysConfig *types.SystemConfig) *EPayService {
|
||||
return &EPayService{
|
||||
config: &sysConfig.Payment.Epay,
|
||||
}
|
||||
}
|
||||
|
||||
type GeekPayParams struct {
|
||||
Method string `json:"method"` // 接口类型
|
||||
Device string `json:"device"` // 设备类型
|
||||
Type string `json:"type"` // 支付方式
|
||||
OutTradeNo string `json:"out_trade_no"` // 商户订单号
|
||||
Name string `json:"name"` // 商品名称
|
||||
Money string `json:"money"` // 商品金额
|
||||
ClientIP string `json:"clientip"` //用户IP地址
|
||||
SubOpenId string `json:"sub_openid"` // 微信用户 openid,仅小程序支付需要
|
||||
SubAppId string `json:"sub_appid"` // 小程序 AppId,仅小程序支付需要
|
||||
NotifyURL string `json:"notify_url"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
func (s *EPayService) UpdateConfig(config *types.EpayConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
// Pay 支付订单
|
||||
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
||||
func (s *EPayService) Pay(params PayRequest) (string, error) {
|
||||
p := map[string]string{
|
||||
"pid": s.config.AppId,
|
||||
//"method": params.Method,
|
||||
"pid": s.config.AppId,
|
||||
"device": params.Device,
|
||||
"type": params.Type,
|
||||
"type": params.PayWay,
|
||||
"out_trade_no": params.OutTradeNo,
|
||||
"name": params.Name,
|
||||
"money": params.Money,
|
||||
"name": params.Subject,
|
||||
"money": params.TotalFee,
|
||||
"clientip": params.ClientIP,
|
||||
"notify_url": params.NotifyURL,
|
||||
"return_url": params.ReturnURL,
|
||||
@@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
||||
}
|
||||
p["sign"] = s.Sign(p)
|
||||
p["sign_type"] = "MD5"
|
||||
return s.sendRequest(s.config.ApiURL, p)
|
||||
resp, err := s.sendRequest(s.config.ApiURL, p)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.Code != 1 {
|
||||
return "", errors.New(resp.Msg)
|
||||
}
|
||||
if resp.PayURL != "" {
|
||||
return resp.PayURL, nil
|
||||
} else {
|
||||
return resp.QrCode, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeekPayService) Sign(params map[string]string) string {
|
||||
func (s *EPayService) Sign(params map[string]string) string {
|
||||
// 按字母顺序排序参数
|
||||
var keys []string
|
||||
for k := range params {
|
||||
@@ -100,7 +100,7 @@ type GeekPayResp struct {
|
||||
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
|
||||
}
|
||||
|
||||
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
||||
func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
||||
form := url.Values{}
|
||||
for k, v := range params {
|
||||
form.Add(k, v)
|
||||
@@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("act", "order")
|
||||
params.Set("pid", s.config.AppId)
|
||||
params.Set("key", s.config.PrivateKey)
|
||||
params.Set("out_trade_no", outTradeNo)
|
||||
|
||||
apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode())
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
resp, err := client.Get(apiURL)
|
||||
if err != nil {
|
||||
return OrderInfo{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return OrderInfo{}, err
|
||||
}
|
||||
logger.Debugf(string(body))
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Status string `json:"status"`
|
||||
Name string `json:"name"`
|
||||
Money string `json:"money"`
|
||||
EndTime string `json:"endtime"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return OrderInfo{}, errors.New("订单查询响应解析失败")
|
||||
}
|
||||
if result.Code != 1 {
|
||||
return OrderInfo{}, errors.New(result.Msg)
|
||||
}
|
||||
logger.Debugf("订单信息:%+v", result)
|
||||
orderInfo := OrderInfo{
|
||||
OutTradeNo: outTradeNo,
|
||||
TradeId: result.TradeNo,
|
||||
Amount: result.Money,
|
||||
PayTime: result.EndTime,
|
||||
}
|
||||
if result.Status == "1" {
|
||||
orderInfo.Status = Success
|
||||
} else {
|
||||
orderInfo.Status = Failure
|
||||
}
|
||||
return orderInfo, nil
|
||||
}
|
||||
|
||||
var _ PayService = (*EPayService)(nil)
|
||||
@@ -1,171 +0,0 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HuPiPayService struct {
|
||||
appId string
|
||||
appSecret string
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
|
||||
return &HuPiPayService{
|
||||
appId: config.HuPiPayConfig.AppId,
|
||||
appSecret: config.HuPiPayConfig.AppSecret,
|
||||
apiURL: config.HuPiPayConfig.ApiURL,
|
||||
}
|
||||
}
|
||||
|
||||
type HuPiPayParams struct {
|
||||
AppId string `json:"appid"`
|
||||
Version string `json:"version"`
|
||||
TradeOrderId string `json:"trade_order_id"`
|
||||
TotalFee string `json:"total_fee"`
|
||||
Title string `json:"title"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
WapName string `json:"wap_name"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
Time string `json:"time"`
|
||||
NonceStr string `json:"nonce_str"`
|
||||
Type string `json:"type"`
|
||||
WapUrl string `json:"wap_url"`
|
||||
}
|
||||
|
||||
type HuPiPayResp struct {
|
||||
Openid interface{} `json:"openid"`
|
||||
UrlQrcode string `json:"url_qrcode"`
|
||||
URL string `json:"url"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
// Pay 执行支付请求操作
|
||||
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
|
||||
data := url.Values{}
|
||||
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
params.AppId = s.appId
|
||||
params.Time = simple
|
||||
params.NonceStr = simple
|
||||
encode := utils.JsonEncode(params)
|
||||
m := make(map[string]string)
|
||||
_ = utils.JsonDecode(encode, &m)
|
||||
for k, v := range m {
|
||||
data.Add(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
// 生成签名
|
||||
data.Add("hash", s.Sign(data))
|
||||
// 发送支付请求
|
||||
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
|
||||
resp, err := http.PostForm(apiURL, data)
|
||||
if err != nil {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
all, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var res HuPiPayResp
|
||||
err = utils.JsonDecode(string(all), &res)
|
||||
if err != nil {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
||||
}
|
||||
|
||||
if res.ErrCode != 0 {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Sign 签名方法
|
||||
func (s *HuPiPayService) Sign(params url.Values) string {
|
||||
params.Del(`Sign`)
|
||||
var keys = make([]string, 0, 0)
|
||||
for key := range params {
|
||||
if params.Get(key) != `` {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var pList = make([]string, 0, 0)
|
||||
for _, key := range keys {
|
||||
var value = strings.TrimSpace(params.Get(key))
|
||||
if len(value) > 0 {
|
||||
pList = append(pList, key+"="+value)
|
||||
}
|
||||
}
|
||||
var src = strings.Join(pList, "&")
|
||||
src += s.appSecret
|
||||
|
||||
md5bs := md5.Sum([]byte(src))
|
||||
return hex.EncodeToString(md5bs[:])
|
||||
}
|
||||
|
||||
// Check 校验订单状态
|
||||
func (s *HuPiPayService) Check(outTradeNo string) error {
|
||||
data := url.Values{}
|
||||
data.Add("appid", s.appId)
|
||||
data.Add("out_trade_order", outTradeNo)
|
||||
stamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
data.Add("time", stamp)
|
||||
data.Add("nonce_str", stamp)
|
||||
data.Add("hash", s.Sign(data))
|
||||
|
||||
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
|
||||
resp, err := http.PostForm(apiURL, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with http reqeust: %v", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var r struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
Data struct {
|
||||
Status string `json:"status"`
|
||||
OpenOrderId string `json:"open_order_id"`
|
||||
} `json:"data,omitempty"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
Hash string `json:"hash"`
|
||||
}
|
||||
err = utils.JsonDecode(string(body), &r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
|
||||
if r.ErrCode == 0 && r.Data.Status == "OD" {
|
||||
return nil
|
||||
} else {
|
||||
logger.Debugf("%+v", r)
|
||||
return errors.New("order not paid:" + r.ErrMsg)
|
||||
}
|
||||
}
|
||||
54
api/service/payment/pay_service.go
Normal file
54
api/service/payment/pay_service.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package payment
|
||||
|
||||
// 支付渠道定义
|
||||
const PayChannelAL = "alipay" // 支付宝
|
||||
const PayChannelWX = "wxpay" // 微信支付
|
||||
const PayChannelEpay = "epay" // 易支付
|
||||
|
||||
// 支付方式
|
||||
const PayWayAL = "alipay"
|
||||
const PayWayWX = "wxpay"
|
||||
|
||||
const (
|
||||
Success = 0
|
||||
Failure = 1
|
||||
Closed = 2
|
||||
)
|
||||
|
||||
type PayRequest struct {
|
||||
OutTradeNo string // 商户订单号
|
||||
Subject string // 商品名称
|
||||
TotalFee string // 商品金额
|
||||
ReturnURL string // 回调地址
|
||||
NotifyURL string // 回调地址
|
||||
|
||||
// 易支付专有参数
|
||||
Method string // 接口类型
|
||||
Device string // 设备类型
|
||||
PayWay string // 支付方式
|
||||
ClientIP string //用户IP地址
|
||||
OpenID string // 用户openid
|
||||
|
||||
}
|
||||
|
||||
type OrderInfo struct {
|
||||
Mchid string // 商户号
|
||||
OutTradeNo string // 商户订单号
|
||||
TradeId string // 交易号
|
||||
Amount string // 金额
|
||||
Status int // 状态 0: 未支付 1: 已支付 2: 已关闭
|
||||
PayTime string // 完成支付时间
|
||||
}
|
||||
|
||||
func (o OrderInfo) Closed() bool {
|
||||
return o.Status == Closed
|
||||
}
|
||||
|
||||
func (o OrderInfo) Success() bool {
|
||||
return o.Status == Success
|
||||
}
|
||||
|
||||
type PayService interface {
|
||||
Pay(params PayRequest) (string, error) // 生成支付链接
|
||||
Query(outTradeNo string) (OrderInfo, error) // 查询订单
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package payment
|
||||
|
||||
type NotifyVo struct {
|
||||
Status int
|
||||
OutTradeNo string // 商户订单号
|
||||
TradeId string // 交易ID
|
||||
Amount string // 交易金额
|
||||
Message string
|
||||
Subject string
|
||||
}
|
||||
|
||||
func (v NotifyVo) Success() bool {
|
||||
return v.Status == Success
|
||||
}
|
||||
|
||||
const (
|
||||
Success = 0
|
||||
Failure = 1
|
||||
)
|
||||
@@ -1,144 +0,0 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/wechat/v3"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WechatPayService struct {
|
||||
config *types.WechatPayConfig
|
||||
client *wechat.ClientV3
|
||||
}
|
||||
|
||||
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
|
||||
config := appConfig.WechatPayConfig
|
||||
if !config.Enabled {
|
||||
logger.Info("Disabled WechatPay service")
|
||||
return nil, nil
|
||||
}
|
||||
priKey, err := readKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
||||
}
|
||||
|
||||
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
|
||||
}
|
||||
err = client.AutoVerifySign()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
|
||||
}
|
||||
//client.DebugSwitch = gopay.DebugOn
|
||||
|
||||
return &WechatPayService{config: &config, client: client}, nil
|
||||
}
|
||||
|
||||
type WechatPayParams struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
TotalFee int `json:"total_fee"`
|
||||
Subject string `json:"subject"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", params.TotalFee).
|
||||
Set("currency", "CNY")
|
||||
})
|
||||
|
||||
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.CodeUrl, nil
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", params.TotalFee).
|
||||
Set("currency", "CNY")
|
||||
}).
|
||||
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("payer_client_ip", params.ClientIP).
|
||||
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("type", "Wap")
|
||||
})
|
||||
})
|
||||
|
||||
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.H5Url, nil
|
||||
}
|
||||
|
||||
type NotifyResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `xml:"message"`
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
|
||||
notifyReq, err := wechat.V3ParseNotify(request)
|
||||
if err != nil {
|
||||
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
|
||||
}
|
||||
|
||||
// TODO: 这里验签程序有 Bug,一直报错:crypto/rsa: verification error,先暂时取消验签
|
||||
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
|
||||
//if err != nil {
|
||||
// return fmt.Errorf("error with client v3 verify sign: %v", err)
|
||||
//}
|
||||
|
||||
// 解密支付密文,验证订单信息
|
||||
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
|
||||
if err != nil {
|
||||
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
|
||||
}
|
||||
|
||||
return NotifyVo{
|
||||
Status: Success,
|
||||
OutTradeNo: result.OutTradeNo,
|
||||
TradeId: result.TransactionId,
|
||||
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
|
||||
}
|
||||
}
|
||||
217
api/service/payment/wxpay_service.go
Normal file
217
api/service/payment/wxpay_service.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/wechat/v3"
|
||||
)
|
||||
|
||||
type WxPayService struct {
|
||||
config *types.WxPayConfig
|
||||
client *wechat.ClientV3
|
||||
}
|
||||
|
||||
func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) {
|
||||
config := sysConfig.Payment.WxPay
|
||||
if !config.Enabled {
|
||||
logger.Debug("Disabled WechatPay service")
|
||||
}
|
||||
|
||||
service := &WxPayService{config: &config}
|
||||
if config.Enabled {
|
||||
err := service.UpdateConfig(&config)
|
||||
if err != nil {
|
||||
logger.Errorf("微信支付服务初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error {
|
||||
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with initialize WechatPay service: %v", err)
|
||||
}
|
||||
err = client.AutoVerifySign()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with autoVerifySign: %v", err)
|
||||
}
|
||||
s.client = client
|
||||
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||
logger.Info("WechatPay Debug mode is enabled")
|
||||
client.DebugSwitch = gopay.DebugOn
|
||||
}
|
||||
s.config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WxPayService) Pay(params PayRequest) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", utils.IntValue(params.TotalFee, 0)).
|
||||
Set("currency", "CNY")
|
||||
})
|
||||
logger.Debugf("wxpay params: %+v", bm)
|
||||
if params.Device == "mobile" {
|
||||
bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("payer_client_ip", params.ClientIP)
|
||||
}).SetBodyMap("payer", func(bm gopay.BodyMap) {
|
||||
bm.Set("openid", params.OpenID)
|
||||
})
|
||||
wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.PrepayId, nil
|
||||
} else if params.Device == "pc" {
|
||||
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.CodeUrl, nil
|
||||
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||
wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err)
|
||||
}
|
||||
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error)
|
||||
}
|
||||
|
||||
if wxRsp.Response.TradeState == "CLOSED" {
|
||||
return OrderInfo{Status: Closed}, nil
|
||||
}
|
||||
|
||||
orderInfo := OrderInfo{
|
||||
OutTradeNo: wxRsp.Response.OutTradeNo,
|
||||
TradeId: wxRsp.Response.TransactionId,
|
||||
Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100),
|
||||
PayTime: wxRsp.Response.SuccessTime,
|
||||
}
|
||||
if wxRsp.Response.TradeState == "SUCCESS" {
|
||||
orderInfo.Status = Success
|
||||
} else {
|
||||
orderInfo.Status = Failure
|
||||
}
|
||||
return orderInfo, nil
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) {
|
||||
notifyReq, err := wechat.V3ParseNotify(request)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err)
|
||||
}
|
||||
|
||||
// 解密支付密文,验证订单信息
|
||||
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err)
|
||||
}
|
||||
|
||||
return OrderInfo{
|
||||
Status: Success,
|
||||
OutTradeNo: result.OutTradeNo,
|
||||
TradeId: result.TransactionId,
|
||||
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
|
||||
PayTime: result.SuccessTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
||||
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// // 初始化 BodyMap
|
||||
// bm := make(gopay.BodyMap)
|
||||
// bm.Set("appid", s.config.AppId).
|
||||
// Set("mchid", s.config.MchId).
|
||||
// Set("description", params.Subject).
|
||||
// Set("out_trade_no", params.OutTradeNo).
|
||||
// Set("time_expire", expire).
|
||||
// Set("notify_url", params.NotifyURL).
|
||||
// SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
// bm.Set("total", params.TotalFee).
|
||||
// Set("currency", "CNY")
|
||||
// })
|
||||
|
||||
// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||
// }
|
||||
// if wxRsp.Code != wechat.Success {
|
||||
// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
// }
|
||||
// return wxRsp.Response.CodeUrl, nil
|
||||
// }
|
||||
|
||||
// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
||||
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// // 初始化 BodyMap
|
||||
// bm := make(gopay.BodyMap)
|
||||
// bm.Set("appid", s.config.AppId).
|
||||
// Set("mchid", s.config.MchId).
|
||||
// Set("description", params.Subject).
|
||||
// Set("out_trade_no", params.OutTradeNo).
|
||||
// Set("time_expire", expire).
|
||||
// Set("notify_url", params.NotifyURL).
|
||||
// SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
// bm.Set("total", params.TotalFee).
|
||||
// Set("currency", "CNY")
|
||||
// }).
|
||||
// SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
// bm.Set("payer_client_ip", params.ClientIP).
|
||||
// SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||
// bm.Set("type", "Wap")
|
||||
// })
|
||||
// })
|
||||
|
||||
// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
|
||||
// }
|
||||
// if wxRsp.Code != wechat.Success {
|
||||
// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
|
||||
// }
|
||||
// return wxRsp.Response.H5Url, nil
|
||||
// }
|
||||
|
||||
// type NotifyResponse struct {
|
||||
// Code string `json:"code"`
|
||||
// Message string `xml:"message"`
|
||||
// }
|
||||
|
||||
var _ PayService = (*WxPayService)(nil)
|
||||
@@ -10,36 +10,57 @@ package sms
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
||||
)
|
||||
|
||||
type AliYunSmsService struct {
|
||||
config *types.SmsConfigAli
|
||||
config types.SmsConfigAli
|
||||
client *dysmsapi.Client
|
||||
domain string
|
||||
zoneId string
|
||||
}
|
||||
|
||||
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
|
||||
config := &appConfig.SMS.Ali
|
||||
// 创建阿里云短信客户端
|
||||
func NewAliYunSmsService(sysConfig *types.SystemConfig) (*AliYunSmsService, error) {
|
||||
config := sysConfig.SMS.Ali
|
||||
domain := "dysmsapi.aliyuncs.com"
|
||||
zoneId := "cn-hangzhou"
|
||||
|
||||
s := AliYunSmsService{
|
||||
config: config,
|
||||
domain: domain,
|
||||
zoneId: zoneId,
|
||||
}
|
||||
if sysConfig.SMS.Active == Ali {
|
||||
err := s.UpdateConfig(config)
|
||||
if err != nil {
|
||||
logger.Errorf("阿里云短信初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func (s *AliYunSmsService) UpdateConfig(config types.SmsConfigAli) error {
|
||||
client, err := dysmsapi.NewClientWithAccessKey(
|
||||
"cn-hangzhou",
|
||||
s.zoneId,
|
||||
config.AccessKey,
|
||||
config.AccessSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create client: %v", err)
|
||||
return fmt.Errorf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
return &AliYunSmsService{
|
||||
config: config,
|
||||
client: client,
|
||||
}, nil
|
||||
s.client = client
|
||||
s.config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
|
||||
if s.client == nil {
|
||||
return fmt.Errorf("阿里云短信服务未初始化")
|
||||
}
|
||||
// 创建短信请求并设置参数
|
||||
request := dysmsapi.CreateSendSmsRequest()
|
||||
request.Scheme = "https"
|
||||
request.Domain = s.config.Domain
|
||||
request.Domain = s.domain
|
||||
request.PhoneNumbers = mobile
|
||||
request.SignName = s.config.Sign
|
||||
request.TemplateCode = s.config.CodeTempId
|
||||
|
||||
@@ -19,20 +19,21 @@ import (
|
||||
)
|
||||
|
||||
type BaoSmsService struct {
|
||||
config *types.SmsConfigBao
|
||||
config types.SmsConfigBao
|
||||
domain string
|
||||
}
|
||||
|
||||
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
|
||||
config := appConfig.SMS.Bao
|
||||
if config.Domain == "" { // use default domain
|
||||
config.Domain = "api.smsbao.com"
|
||||
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
|
||||
}
|
||||
func NewBaoSmsService(sysConfig *types.SystemConfig) *BaoSmsService {
|
||||
return &BaoSmsService{
|
||||
config: &config,
|
||||
config: sysConfig.SMS.Bao,
|
||||
domain: "api.smsbao.com",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaoSmsService) UpdateConfig(config types.SmsConfigBao) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
var errMsg = map[string]string{
|
||||
"0": "短信发送成功",
|
||||
"-1": "参数不全",
|
||||
@@ -56,7 +57,7 @@ func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
|
||||
params.Set("m", mobile)
|
||||
params.Set("c", content)
|
||||
|
||||
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
|
||||
apiURL := fmt.Sprintf("https://%s/sms?%s", s.domain, params.Encode())
|
||||
response, err := http.Get(apiURL)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -7,8 +7,8 @@ package sms
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
const Ali = "ALI"
|
||||
const Bao = "BAO"
|
||||
const Ali = "aliyun"
|
||||
const Bao = "bao"
|
||||
|
||||
type Service interface {
|
||||
SendVerifyCode(mobile string, code int) error
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ServiceManager struct {
|
||||
handler Service
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
|
||||
active := Ali
|
||||
if config.SMS.Active != "" {
|
||||
active = strings.ToUpper(config.SMS.Active)
|
||||
}
|
||||
var handler Service
|
||||
switch active {
|
||||
case Ali:
|
||||
client, err := NewAliYunSmsService(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler = client
|
||||
break
|
||||
case Bao:
|
||||
handler = NewSmsBaoSmsService(config)
|
||||
break
|
||||
}
|
||||
|
||||
return &ServiceManager{handler: handler}, nil
|
||||
}
|
||||
|
||||
func (m *ServiceManager) GetService() Service {
|
||||
return m.handler
|
||||
}
|
||||
54
api/service/sms/sms_manager.go
Normal file
54
api/service/sms/sms_manager.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
)
|
||||
|
||||
type SmsManager struct {
|
||||
aliyun *AliYunSmsService
|
||||
bao *BaoSmsService
|
||||
active string
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewSmsManager(sysConfig *types.SystemConfig, aliyun *AliYunSmsService, bao *BaoSmsService) (*SmsManager, error) {
|
||||
|
||||
return &SmsManager{
|
||||
active: sysConfig.SMS.Active,
|
||||
aliyun: aliyun,
|
||||
bao: bao,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *SmsManager) GetService() Service {
|
||||
switch m.active {
|
||||
case Ali:
|
||||
return m.aliyun
|
||||
case Bao:
|
||||
return m.bao
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SmsManager) SetActive(active string) {
|
||||
m.active = active
|
||||
}
|
||||
|
||||
func (m *SmsManager) UpdateConfig(config types.SMSConfig) {
|
||||
switch config.Active {
|
||||
case Ali:
|
||||
m.aliyun.UpdateConfig(config.Ali)
|
||||
case Bao:
|
||||
m.bao.UpdateConfig(config.Bao)
|
||||
}
|
||||
m.active = config.Active
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user