mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	Compare commits
	
		
			105 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					bc7d06d3e5 | ||
| 
						 | 
					8e81dfa12a | ||
| 
						 | 
					0ff76f0f21 | ||
| 
						 | 
					787caa84c8 | ||
| 
						 | 
					c2503e663a | ||
| 
						 | 
					405a88862b | ||
| 
						 | 
					296eabe09a | ||
| 
						 | 
					54b45ec2ff | ||
| 
						 | 
					c434f85045 | ||
| 
						 | 
					4d10279870 | ||
| 
						 | 
					9de9489673 | ||
| 
						 | 
					9814fec930 | ||
| 
						 | 
					53ba731159 | ||
| 
						 | 
					b2f57aa483 | ||
| 
						 | 
					4c2dba1004 | ||
| 
						 | 
					cdaf6fb9dc | ||
| 
						 | 
					78f443ed6d | ||
| 
						 | 
					54e8d72b10 | ||
| 
						 | 
					05161f48fd | ||
| 
						 | 
					e971bf6b88 | ||
| 
						 | 
					55b979784c | ||
| 
						 | 
					79adc871ef | ||
| 
						 | 
					97aa922b5f | ||
| 
						 | 
					11c760a4e8 | ||
| 
						 | 
					8144fada25 | ||
| 
						 | 
					87b03332d9 | ||
| 
						 | 
					8b14eeadf4 | ||
| 
						 | 
					e0ead127e0 | ||
| 
						 | 
					0887bcdee0 | ||
| 
						 | 
					67d83041d7 | ||
| 
						 | 
					1350f388f0 | ||
| 
						 | 
					65dde9e69d | ||
| 
						 | 
					2e5bd238b7 | ||
| 
						 | 
					8fc8fd6cba | ||
| 
						 | 
					dfc6c87250 | ||
| 
						 | 
					b63e01225e | ||
| 
						 | 
					561b82027a | ||
| 
						 | 
					f6d8fbf570 | ||
| 
						 | 
					568201ebbb | ||
| 
						 | 
					ab421f2185 | ||
| 
						 | 
					f71a2f5263 | ||
| 
						 | 
					d000cc5a67 | ||
| 
						 | 
					04d6ba0853 | ||
| 
						 | 
					8d7c028ca8 | ||
| 
						 | 
					3ae7ebfeaf | ||
| 
						 | 
					aa42d38387 | ||
| 
						 | 
					43843b92f2 | ||
| 
						 | 
					5da879600a | ||
| 
						 | 
					87ed2064e3 | ||
| 
						 | 
					34e96e91d4 | ||
| 
						 | 
					8c4c2b89ce | ||
| 
						 | 
					373021c191 | ||
| 
						 | 
					740c3c1b00 | ||
| 
						 | 
					67c7132e6b | ||
| 
						 | 
					c77843424b | ||
| 
						 | 
					2d4959aa7d | ||
| 
						 | 
					167c59a159 | ||
| 
						 | 
					1d0006ce59 | ||
| 
						 | 
					6a8b4ee2f1 | ||
| 
						 | 
					72b1515b68 | ||
| 
						 | 
					3f0252b498 | ||
| 
						 | 
					1d9d487f0e | ||
| 
						 | 
					96f1126d02 | ||
| 
						 | 
					7f9b8d8246 | ||
| 
						 | 
					5132d52a44 | ||
| 
						 | 
					1bcbf74883 | ||
| 
						 | 
					abdf5298fe | ||
| 
						 | 
					2129f7a8b7 | ||
| 
						 | 
					f6f8748521 | ||
| 
						 | 
					59301df073 | ||
| 
						 | 
					e17dcf4d5f | ||
| 
						 | 
					09f44e6d9b | ||
| 
						 | 
					59824bffc5 | ||
| 
						 | 
					cb0dacd5e0 | ||
| 
						 | 
					7463cfc66c | ||
| 
						 | 
					b248560ba2 | ||
| 
						 | 
					37368fe13f | ||
| 
						 | 
					246b023624 | ||
| 
						 | 
					a6b9f57a50 | ||
| 
						 | 
					42bc23cacf | ||
| 
						 | 
					282f55c7a3 | ||
| 
						 | 
					44798f89ba | ||
| 
						 | 
					596cb2b206 | ||
| 
						 | 
					d1965deff1 | ||
| 
						 | 
					a5ef4299ec | ||
| 
						 | 
					cdb1a8bde1 | ||
| 
						 | 
					64e5fc48ba | ||
| 
						 | 
					a692cf1338 | ||
| 
						 | 
					6998dd7af4 | ||
| 
						 | 
					9343c73e0f | ||
| 
						 | 
					739cd46539 | ||
| 
						 | 
					f8fed83507 | ||
| 
						 | 
					d63536d5ef | ||
| 
						 | 
					4905fb28d4 | ||
| 
						 | 
					a3a2a8abcb | ||
| 
						 | 
					839dd8dbf4 | ||
| 
						 | 
					0375164f40 | ||
| 
						 | 
					691294b444 | ||
| 
						 | 
					c24b4d7074 | ||
| 
						 | 
					ab24398748 | ||
| 
						 | 
					6110522b54 | ||
| 
						 | 
					bcdf5e3776 | ||
| 
						 | 
					2207830db9 | ||
| 
						 | 
					d52dfbfef4 | ||
| 
						 | 
					66ccb387e8 | 
							
								
								
									
										59
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								CHANGELOG.md
									
									
									
									
									
								
							@@ -1,4 +1,63 @@
 | 
			
		||||
# 更新日志
 | 
			
		||||
 | 
			
		||||
## v4.1.3
 | 
			
		||||
* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
 | 
			
		||||
* 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码
 | 
			
		||||
* 功能优化:管理后台给可以拖动排序的组件添加拖动图标
 | 
			
		||||
* 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
 | 
			
		||||
* Bug修复:手机端角色和模型选择不生效
 | 
			
		||||
* Bug修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常
 | 
			
		||||
* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
 | 
			
		||||
* 功能优化:给算力增减服务全部加上数据库事务和同步锁
 | 
			
		||||
* 功能优化:支持用户在前端对话界面选择插件
 | 
			
		||||
* 功能新增:支持 Luma 文生视频功能
 | 
			
		||||
 | 
			
		||||
## v4.1.2
 | 
			
		||||
* Bug修复:修复思维导图页面获取模型失败的问题
 | 
			
		||||
* 功能优化:优化MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
 | 
			
		||||
* Bug修复:修复后台拖动排序组件 Bug
 | 
			
		||||
* 功能优化:更新数据库失败时候显示具体的的报错信息
 | 
			
		||||
* Bug修复:修复管理后台对话详情页内容显示异常问题
 | 
			
		||||
* 功能优化:管理后台新增清空所有未支付订单的功能
 | 
			
		||||
* 功能优化:给会话信息和系统配置数据加上缓存功能,减少 http 请求
 | 
			
		||||
* 功能新增:移除微信机器人收款功能,增加卡密功能,支持用户使用卡密兑换算力
 | 
			
		||||
 | 
			
		||||
## v4.1.1
 | 
			
		||||
* Bug修复:修复 GPT 模型 function call 调用后没有输出的问题
 | 
			
		||||
* 功能新增:允许获取 License 授权用户可以自定义版权信息
 | 
			
		||||
* 功能新增:聊天对话框支持粘贴剪切板内容来上传截图和文件
 | 
			
		||||
* 功能优化:增加 session 和系统配置缓存,确保每个页面只进行一次 session 和 get system config 请求
 | 
			
		||||
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
 | 
			
		||||
* 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
 | 
			
		||||
* 功能新增:允许在设置首页纯色背景,背景图片,随机背景图片三种背景模式
 | 
			
		||||
* 功能新增:允许在管理后台设置首页显示的导航菜单
 | 
			
		||||
* Bug修复:修复注册页面先显示关闭注册组件,然后再显示注册组件
 | 
			
		||||
* 功能新增:增加 Suno 文生歌曲功能
 | 
			
		||||
* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
 | 
			
		||||
* 功能优化:在所有列表页面增加返回顶部按钮
 | 
			
		||||
 | 
			
		||||
## v4.1.0
 | 
			
		||||
* bug修复:修复移动端修改聊天标题不生效的问题
 | 
			
		||||
* Bug修复:修复用户注册不显示用户名的问题
 | 
			
		||||
* Bug修复:修复管理后台拖动排序不生效的问题
 | 
			
		||||
* 功能优化:允许用户设置自定义首页背景图片
 | 
			
		||||
* 功能新增:**支持AI解读 PDF, Word, Excel等文件**
 | 
			
		||||
* 功能优化:优化聊天界面的用户上传文件的列表样式
 | 
			
		||||
* 功能优化:优化聊天页面对话样式,支持列表样式和对话样式切换
 | 
			
		||||
* 功能新增:支持微信扫码登录,未注册用户微信扫码后会自动注册并登录。移动使用微信浏览器打开可以实现无感登录。
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## v4.0.9
 | 
			
		||||
* 环境升级:升级 Golang 到 go1.22.4
 | 
			
		||||
* 功能增加:接入微信商户号支付渠道
 | 
			
		||||
* Bug修复:修复前端页面菜单把页面撑开,底部留白问题
 | 
			
		||||
* 功能优化:聊天页面自动根据内容调整输入框的高度
 | 
			
		||||
* Bug修复:修复Dalle绘图失败退回算力的问题
 | 
			
		||||
* 功能优化:邀请码注册时被邀请人也可以获得赠送的算力
 | 
			
		||||
* 功能优化:允许设置邮件验证码的抬头
 | 
			
		||||
* Bug修复:修复免费模型不会记录聊天记录的bug
 | 
			
		||||
* Bug修复:修复聊天输入公式显示异常的Bug
 | 
			
		||||
 | 
			
		||||
## v4.0.8
 | 
			
		||||
* 功能优化:升级 mathjax 公式解析插件,修复公式因为图片访问限制而无法显示的问题
 | 
			
		||||
* 功能优化:当数据库更新失败的时候记录错误日志
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								api/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								api/.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -18,3 +18,4 @@ data
 | 
			
		||||
config.toml
 | 
			
		||||
static/upload 
 | 
			
		||||
storage.json 
 | 
			
		||||
res/certs/wechat/apiclient_key.pem
 | 
			
		||||
 
 | 
			
		||||
@@ -3,8 +3,7 @@ ProxyURL = "" # 如 http://127.0.0.1:7777
 | 
			
		||||
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
 | 
			
		||||
StaticDir = "./static" # 静态资源的目录
 | 
			
		||||
StaticUrl = "/static" # 静态资源访问 URL
 | 
			
		||||
AesEncryptKey = ""
 | 
			
		||||
WeChatBot = false
 | 
			
		||||
TikaHost = "http://tika:9998"
 | 
			
		||||
 | 
			
		||||
[Session]
 | 
			
		||||
  SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
 | 
			
		||||
@@ -17,7 +16,7 @@ WeChatBot = false
 | 
			
		||||
  DB = 0
 | 
			
		||||
 | 
			
		||||
[ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通
 | 
			
		||||
  ApiURL = ""
 | 
			
		||||
  ApiURL = "https://sapi.geekai.me"
 | 
			
		||||
  AppId = ""
 | 
			
		||||
  Token = ""
 | 
			
		||||
 | 
			
		||||
@@ -64,23 +63,6 @@ WeChatBot = false
 | 
			
		||||
       SubDir = ""
 | 
			
		||||
       Domain = ""
 | 
			
		||||
 | 
			
		||||
[[MjProxyConfigs]]
 | 
			
		||||
  Enabled = true
 | 
			
		||||
  ApiURL = "http://midjourney-proxy:8082"
 | 
			
		||||
  ApiKey = "sk-geekmaster"
 | 
			
		||||
 | 
			
		||||
[[MjPlusConfigs]]
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  ApiURL = "https://api.chat-plus.net"
 | 
			
		||||
  Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
 | 
			
		||||
  ApiKey = "sk-xxx"
 | 
			
		||||
 | 
			
		||||
[[SdConfigs]]
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  ApiURL = ""
 | 
			
		||||
  ApiKey = ""
 | 
			
		||||
  Txt2ImgJsonPath = "res/sd/text2img.json"
 | 
			
		||||
 | 
			
		||||
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动
 | 
			
		||||
  Enabled = false # 是否启用 XXL JOB 服务
 | 
			
		||||
  ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
 | 
			
		||||
@@ -123,3 +105,15 @@ WeChatBot = false
 | 
			
		||||
  PrivateKey = "" # 秘钥
 | 
			
		||||
  ApiURL = "https://payjs.cn"
 | 
			
		||||
  NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
 | 
			
		||||
 | 
			
		||||
# 微信商户支付
 | 
			
		||||
[WechatPayConfig]
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  AppId = "" # 商户应用ID
 | 
			
		||||
  MchId = "" # 商户号
 | 
			
		||||
  SerialNo = "" # API 证书序列号
 | 
			
		||||
  PrivateKey = "certs/alipay/privateKey.txt" # API 证书私钥文件路径,跟支付宝一样,把私钥文件拷贝到对应的路径,证书路径要映射到容器内
 | 
			
		||||
  ApiV3Key = "" # APIV3 私钥,这个是你自己在微信支付平台设置的
 | 
			
		||||
  NotifyURL = "https://ai.r9it.com/api/payment/wechat/notify" # 支付成功异步回调地址,域名改成自己的
 | 
			
		||||
  ReturnURL = "" # 支付成功同步回调地址
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -32,31 +32,19 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AppServer struct {
 | 
			
		||||
	Debug        bool
 | 
			
		||||
	Config       *types.AppConfig
 | 
			
		||||
	Engine       *gin.Engine
 | 
			
		||||
	ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
 | 
			
		||||
	Debug     bool
 | 
			
		||||
	Config    *types.AppConfig
 | 
			
		||||
	Engine    *gin.Engine
 | 
			
		||||
	SysConfig *types.SystemConfig // system config cache
 | 
			
		||||
 | 
			
		||||
	// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
 | 
			
		||||
	// 防止第三方直接连接 socket 调用 OpenAI API
 | 
			
		||||
	ChatSession   *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
 | 
			
		||||
	ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合
 | 
			
		||||
	ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServer(appConfig *types.AppConfig) *AppServer {
 | 
			
		||||
	gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
	gin.DefaultWriter = io.Discard
 | 
			
		||||
	return &AppServer{
 | 
			
		||||
		Debug:         false,
 | 
			
		||||
		Config:        appConfig,
 | 
			
		||||
		Engine:        gin.Default(),
 | 
			
		||||
		ChatContexts:  types.NewLMap[string, []types.Message](),
 | 
			
		||||
		ChatSession:   types.NewLMap[string, *types.ChatSession](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
			
		||||
		Debug:  false,
 | 
			
		||||
		Config: appConfig,
 | 
			
		||||
		Engine: gin.Default(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -95,7 +83,7 @@ func errorHandler(c *gin.Context) {
 | 
			
		||||
		if r := recover(); r != nil {
 | 
			
		||||
			logger.Errorf("Handler Panic: %v", r)
 | 
			
		||||
			debug.PrintStack()
 | 
			
		||||
			c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
 | 
			
		||||
			c.JSON(http.StatusBadRequest, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
@@ -151,7 +139,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
			
		||||
 | 
			
		||||
		if tokenString == "" {
 | 
			
		||||
			if needLogin(c) {
 | 
			
		||||
				resp.ERROR(c, "You should put Authorization in request headers")
 | 
			
		||||
				resp.NotAuth(c, "You should put Authorization in request headers")
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				return
 | 
			
		||||
			} else { // 直接放行
 | 
			
		||||
@@ -213,7 +201,6 @@ func needLogin(c *gin.Context) bool {
 | 
			
		||||
		c.Request.URL.Path == "/api/admin/logout" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/admin/login/captcha" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/user/register" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/user/session" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/chat/history" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/chat/detail" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/chat/list" ||
 | 
			
		||||
@@ -233,9 +220,16 @@ func needLogin(c *gin.Context) bool {
 | 
			
		||||
		c.Request.URL.Path == "/api/payment/alipay/notify" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/payment/hupipay/notify" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/payment/payjs/notify" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/payment/wechat/notify" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/payment/doPay" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/payment/payWays" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/suno/client" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/suno/detail" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/suno/play" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/download" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/video/client" ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
 | 
			
		||||
@@ -374,6 +368,7 @@ func staticResourceMiddleware() gin.HandlerFunc {
 | 
			
		||||
			// 直接输出图像数据流
 | 
			
		||||
			c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
 | 
			
		||||
			c.Abort() // 中断请求
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -38,7 +38,6 @@ func NewDefaultConfig() *types.AppConfig {
 | 
			
		||||
				BasePath: "./static/upload",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		WeChatBot:    false,
 | 
			
		||||
		AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -53,16 +53,15 @@ type Delta struct {
 | 
			
		||||
// ChatSession 聊天会话对象
 | 
			
		||||
type ChatSession struct {
 | 
			
		||||
	SessionId string    `json:"session_id"`
 | 
			
		||||
	UserId    uint      `json:"user_id"`
 | 
			
		||||
	ClientIP  string    `json:"client_ip"` // 客户端 IP
 | 
			
		||||
	Username  string    `json:"username"`  // 当前登录的 username
 | 
			
		||||
	UserId    uint      `json:"user_id"`   // 当前登录的 user ID
 | 
			
		||||
	ChatId    string    `json:"chat_id"`   // 客户端聊天会话 ID, 多会话模式专用字段
 | 
			
		||||
	Model     ChatModel `json:"model"`     // GPT 模型
 | 
			
		||||
	Tools     string    `json:"tools"`     // 函数
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatModel struct {
 | 
			
		||||
	Id          uint    `json:"id"`
 | 
			
		||||
	Platform    string  `json:"platform"`
 | 
			
		||||
	Name        string  `json:"name"`
 | 
			
		||||
	Value       string  `json:"value"`
 | 
			
		||||
	Power       int     `json:"power"`
 | 
			
		||||
@@ -92,7 +91,7 @@ const (
 | 
			
		||||
	PowerConsume  = PowerType(2) // 消费
 | 
			
		||||
	PowerRefund   = PowerType(3) // 任务(SD,MJ)执行失败,退款
 | 
			
		||||
	PowerInvite   = PowerType(4) // 邀请奖励
 | 
			
		||||
	PowerReward   = PowerType(5) // 众筹
 | 
			
		||||
	PowerRedeem   = PowerType(5) // 众筹
 | 
			
		||||
	PowerGift     = PowerType(6) // 系统赠送
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -104,8 +103,8 @@ func (t PowerType) String() string {
 | 
			
		||||
		return "消费"
 | 
			
		||||
	case PowerRefund:
 | 
			
		||||
		return "退款"
 | 
			
		||||
	case PowerReward:
 | 
			
		||||
		return "众筹"
 | 
			
		||||
	case PowerRedeem:
 | 
			
		||||
		return "兑换"
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	return "其他"
 | 
			
		||||
 
 | 
			
		||||
@@ -12,28 +12,26 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AppConfig struct {
 | 
			
		||||
	Path           string `toml:"-"`
 | 
			
		||||
	Listen         string
 | 
			
		||||
	Session        Session
 | 
			
		||||
	AdminSession   Session
 | 
			
		||||
	ProxyURL       string
 | 
			
		||||
	MysqlDns       string                  // mysql 连接地址
 | 
			
		||||
	StaticDir      string                  // 静态资源目录
 | 
			
		||||
	StaticUrl      string                  // 静态资源 URL
 | 
			
		||||
	Redis          RedisConfig             // redis 连接信息
 | 
			
		||||
	ApiConfig      ApiConfig               // ChatPlus API authorization configs
 | 
			
		||||
	SMS            SMSConfig               // send mobile message config
 | 
			
		||||
	OSS            OSSConfig               // OSS config
 | 
			
		||||
	MjProxyConfigs []MjProxyConfig         // MJ proxy config
 | 
			
		||||
	MjPlusConfigs  []MjPlusConfig          // MJ plus config
 | 
			
		||||
	WeChatBot      bool                    // 是否启用微信机器人
 | 
			
		||||
	SdConfigs      []StableDiffusionConfig // sd AI draw service pool
 | 
			
		||||
	Path         string `toml:"-"`
 | 
			
		||||
	Listen       string
 | 
			
		||||
	Session      Session
 | 
			
		||||
	AdminSession Session
 | 
			
		||||
	ProxyURL     string
 | 
			
		||||
	MysqlDns     string      // mysql 连接地址
 | 
			
		||||
	StaticDir    string      // 静态资源目录
 | 
			
		||||
	StaticUrl    string      // 静态资源 URL
 | 
			
		||||
	Redis        RedisConfig // redis 连接信息
 | 
			
		||||
	ApiConfig    ApiConfig   // ChatPlus API authorization configs
 | 
			
		||||
	SMS          SMSConfig   // send mobile message config
 | 
			
		||||
	OSS          OSSConfig   // OSS config
 | 
			
		||||
 | 
			
		||||
	XXLConfig     XXLConfig
 | 
			
		||||
	AlipayConfig  AlipayConfig
 | 
			
		||||
	HuPiPayConfig HuPiPayConfig
 | 
			
		||||
	SmtpConfig    SmtpConfig // 邮件发送配置
 | 
			
		||||
	JPayConfig    JPayConfig // payjs 支付配置
 | 
			
		||||
	XXLConfig       XXLConfig
 | 
			
		||||
	AlipayConfig    AlipayConfig    // 支付宝支付渠道配置
 | 
			
		||||
	HuPiPayConfig   HuPiPayConfig   // 虎皮椒支付配置
 | 
			
		||||
	SmtpConfig      SmtpConfig      // 邮件发送配置
 | 
			
		||||
	JPayConfig      JPayConfig      // payjs 支付配置
 | 
			
		||||
	WechatPayConfig WechatPayConfig // 微信支付渠道配置
 | 
			
		||||
	TikaHost        string          // TiKa 服务器地址
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SmtpConfig struct {
 | 
			
		||||
@@ -51,27 +49,6 @@ type ApiConfig struct {
 | 
			
		||||
	Token  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MjProxyConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StableDiffusionConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	Model   string // 模型名称
 | 
			
		||||
	ApiURL  string
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MjPlusConfig struct {
 | 
			
		||||
	Enabled bool   // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AlipayConfig struct {
 | 
			
		||||
	Enabled         bool   // 是否启用该支付通道
 | 
			
		||||
	SandBox         bool   // 是否沙盒环境
 | 
			
		||||
@@ -85,6 +62,17 @@ type AlipayConfig struct {
 | 
			
		||||
	ReturnURL       string // 支付成功返回地址
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WechatPayConfig struct {
 | 
			
		||||
	Enabled    bool   // 是否启用该支付通道
 | 
			
		||||
	AppId      string // 公众号的APPID,如:wxd678efh567hg6787
 | 
			
		||||
	MchId      string // 直连商户的商户号,由微信支付生成并下发
 | 
			
		||||
	SerialNo   string // 商户证书的证书序列号
 | 
			
		||||
	PrivateKey string // 用户私钥文件路径
 | 
			
		||||
	ApiV3Key   string // API V3 秘钥
 | 
			
		||||
	NotifyURL  string // 异步通知回调
 | 
			
		||||
	ReturnURL  string // 支付成功返回地址
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
 | 
			
		||||
	Enabled   bool   // 是否启用该支付通道
 | 
			
		||||
	Name      string // 支付名称,如:wechat/alipay
 | 
			
		||||
@@ -142,49 +130,11 @@ func (c RedisConfig) Url() string {
 | 
			
		||||
	return fmt.Sprintf("%s:%d", c.Host, c.Port)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Platform struct {
 | 
			
		||||
	Name    string `json:"name"`
 | 
			
		||||
	Value   string `json:"value"`
 | 
			
		||||
	ChatURL string `json:"chat_url"`
 | 
			
		||||
	ImgURL  string `json:"img_url"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var OpenAI = Platform{
 | 
			
		||||
	Name:    "OpenAI - GPT",
 | 
			
		||||
	Value:   "OpenAI",
 | 
			
		||||
	ChatURL: "https://api.chat-plus.net/v1/chat/completions",
 | 
			
		||||
	ImgURL:  "https://api.chat-plus.net/v1/images/generations",
 | 
			
		||||
}
 | 
			
		||||
var Azure = Platform{
 | 
			
		||||
	Name:    "微软 - Azure",
 | 
			
		||||
	Value:   "Azure",
 | 
			
		||||
	ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15",
 | 
			
		||||
}
 | 
			
		||||
var ChatGLM = Platform{
 | 
			
		||||
	Name:    "智谱 - ChatGLM",
 | 
			
		||||
	Value:   "ChatGLM",
 | 
			
		||||
	ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke",
 | 
			
		||||
}
 | 
			
		||||
var Baidu = Platform{
 | 
			
		||||
	Name:    "百度 - 文心大模型",
 | 
			
		||||
	Value:   "Baidu",
 | 
			
		||||
	ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
 | 
			
		||||
}
 | 
			
		||||
var XunFei = Platform{
 | 
			
		||||
	Name:    "讯飞 - 星火大模型",
 | 
			
		||||
	Value:   "XunFei",
 | 
			
		||||
	ChatURL: "wss://spark-api.xf-yun.com/{version}/chat",
 | 
			
		||||
}
 | 
			
		||||
var QWen = Platform{
 | 
			
		||||
	Name:    "阿里 - 通义千问",
 | 
			
		||||
	Value:   "QWen",
 | 
			
		||||
	ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SystemConfig struct {
 | 
			
		||||
	Title         string `json:"title,omitempty"`
 | 
			
		||||
	AdminTitle    string `json:"admin_title,omitempty"`
 | 
			
		||||
	Logo          string `json:"logo,omitempty"`
 | 
			
		||||
	Title         string `json:"title,omitempty"`           // 网站标题
 | 
			
		||||
	Slogan        string `json:"slogan,omitempty"`          // 网站 slogan
 | 
			
		||||
	AdminTitle    string `json:"admin_title,omitempty"`     // 管理后台标题
 | 
			
		||||
	Logo          string `json:"logo,omitempty"`            // 方形 Logo
 | 
			
		||||
	InitPower     int    `json:"init_power,omitempty"`      // 新用户注册赠送算力值
 | 
			
		||||
	DailyPower    int    `json:"daily_power,omitempty"`     // 每日赠送算力
 | 
			
		||||
	InvitePower   int    `json:"invite_power,omitempty"`    // 邀请新用户赠送算力值
 | 
			
		||||
@@ -193,10 +143,6 @@ type SystemConfig struct {
 | 
			
		||||
	RegisterWays    []string `json:"register_ways,omitempty"`    // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
 | 
			
		||||
	EnabledRegister bool     `json:"enabled_register,omitempty"` // 是否开放注册
 | 
			
		||||
 | 
			
		||||
	RewardImg     string  `json:"reward_img,omitempty"`     // 众筹收款二维码地址
 | 
			
		||||
	EnabledReward bool    `json:"enabled_reward,omitempty"` // 启用众筹功能
 | 
			
		||||
	PowerPrice    float64 `json:"power_price,omitempty"`    // 算力单价
 | 
			
		||||
 | 
			
		||||
	OrderPayTimeout int    `json:"order_pay_timeout,omitempty"` //订单支付超时时间
 | 
			
		||||
	VipInfoText     string `json:"vip_info_text,omitempty"`     // 会员页面充值说明
 | 
			
		||||
	DefaultModels   []int  `json:"default_models,omitempty"`    // 默认开通的 AI 模型
 | 
			
		||||
@@ -204,7 +150,9 @@ type SystemConfig struct {
 | 
			
		||||
	MjPower       int `json:"mj_power,omitempty"`        // MJ 绘画消耗算力
 | 
			
		||||
	MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
 | 
			
		||||
	SdPower       int `json:"sd_power,omitempty"`        // SD 绘画消耗算力
 | 
			
		||||
	DallPower     int `json:"dall_power,omitempty"`      // DALLE3 绘图消耗算力
 | 
			
		||||
	DallPower     int `json:"dall_power,omitempty"`      // DALL-E-3 绘图消耗算力
 | 
			
		||||
	SunoPower     int `json:"suno_power,omitempty"`      // Suno 生成歌曲消耗算力
 | 
			
		||||
	LumaPower     int `json:"luma_power,omitempty"`      // Luma 生成视频消耗算力
 | 
			
		||||
 | 
			
		||||
	WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
 | 
			
		||||
 | 
			
		||||
@@ -212,6 +160,12 @@ type SystemConfig struct {
 | 
			
		||||
	ContextDeep   int  `json:"context_deep,omitempty"`
 | 
			
		||||
 | 
			
		||||
	SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
 | 
			
		||||
	MjMode      string `json:"mj_mode"`       // midjourney 默认的API模式,relax, fast, turbo
 | 
			
		||||
 | 
			
		||||
	RandBg bool `json:"rand_bg"` // 前端首页是否启用随机背景
 | 
			
		||||
	IndexBgURL  string `json:"index_bg_url"`  // 前端首页背景图片
 | 
			
		||||
	IndexNavs   []int  `json:"index_navs"`    // 首页显示的导航菜单
 | 
			
		||||
	Copyright   string `json:"copyright"`     // 版权信息
 | 
			
		||||
	MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
 | 
			
		||||
 | 
			
		||||
	EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -27,8 +27,6 @@ type MjTask struct {
 | 
			
		||||
	Id          uint     `json:"id"`
 | 
			
		||||
	TaskId      string   `json:"task_id"`
 | 
			
		||||
	ImgArr      []string `json:"img_arr"`
 | 
			
		||||
	ChannelId   string   `json:"channel_id"`
 | 
			
		||||
	SessionId   string   `json:"session_id"`
 | 
			
		||||
	Type        TaskType `json:"type"`
 | 
			
		||||
	UserId      int      `json:"user_id"`
 | 
			
		||||
	Prompt      string   `json:"prompt,omitempty"`
 | 
			
		||||
@@ -38,11 +36,12 @@ type MjTask struct {
 | 
			
		||||
	MessageId   string   `json:"message_id,omitempty"`
 | 
			
		||||
	MessageHash string   `json:"message_hash,omitempty"`
 | 
			
		||||
	RetryCount  int      `json:"retry_count"`
 | 
			
		||||
	ChannelId   string   `json:"channel_id"` // 渠道ID,用来区分是哪个渠道创建的任务,一个任务的 create 和 action 操作必须要再同一个渠道
 | 
			
		||||
	Mode        string   `json:"mode"`       // 绘画模式,relax, fast, turbo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SdTask struct {
 | 
			
		||||
	Id         int          `json:"id"` // job 数据库ID
 | 
			
		||||
	SessionId  string       `json:"session_id"`
 | 
			
		||||
	Type       TaskType     `json:"type"`
 | 
			
		||||
	UserId     int          `json:"user_id"`
 | 
			
		||||
	Params     SdTaskParams `json:"params"`
 | 
			
		||||
@@ -55,10 +54,10 @@ type SdTaskParams struct {
 | 
			
		||||
	NegPrompt    string  `json:"neg_prompt"` // 反向提示词
 | 
			
		||||
	Steps        int     `json:"steps"`      // 迭代步数,默认20
 | 
			
		||||
	Sampler      string  `json:"sampler"`    // 采样器
 | 
			
		||||
	Scheduler    string  `json:"scheduler"`
 | 
			
		||||
	FaceFix      bool    `json:"face_fix"`  // 面部修复
 | 
			
		||||
	CfgScale     float32 `json:"cfg_scale"` //引导系数,默认 7
 | 
			
		||||
	Seed         int64   `json:"seed"`      // 随机数种子
 | 
			
		||||
	Scheduler    string  `json:"scheduler"`  // 采样调度
 | 
			
		||||
	FaceFix      bool    `json:"face_fix"`   // 面部修复
 | 
			
		||||
	CfgScale     float32 `json:"cfg_scale"`  //引导系数,默认 7
 | 
			
		||||
	Seed         int64   `json:"seed"`       // 随机数种子
 | 
			
		||||
	Height       int     `json:"height"`
 | 
			
		||||
	Width        int     `json:"width"`
 | 
			
		||||
	HdFix        bool    `json:"hd_fix"`         // 启用高清修复
 | 
			
		||||
@@ -80,3 +79,47 @@ type DallTask struct {
 | 
			
		||||
 | 
			
		||||
	Power int `json:"power"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SunoTask struct {
 | 
			
		||||
	Id           uint   `json:"id"`
 | 
			
		||||
	Channel      string `json:"channel"`
 | 
			
		||||
	UserId       int    `json:"user_id"`
 | 
			
		||||
	Type         int    `json:"type"`
 | 
			
		||||
	Title        string `json:"title"`
 | 
			
		||||
	RefTaskId    string `json:"ref_task_id,omitempty"`
 | 
			
		||||
	RefSongId    string `json:"ref_song_id,omitempty"`
 | 
			
		||||
	Prompt       string `json:"prompt"` // 提示词/歌词
 | 
			
		||||
	Tags         string `json:"tags"`
 | 
			
		||||
	Model        string `json:"model"`
 | 
			
		||||
	Instrumental bool   `json:"instrumental"`          // 是否纯音乐
 | 
			
		||||
	ExtendSecs   int    `json:"extend_secs,omitempty"` // 延长秒杀
 | 
			
		||||
	SongId       string `json:"song_id,omitempty"`     // 合并歌曲ID
 | 
			
		||||
	AudioURL     string `json:"audio_url"`             // 用户上传音频地址
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	VideoLuma   = "luma"
 | 
			
		||||
	VideoRunway = "runway"
 | 
			
		||||
	VideoCog    = "cog"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type VideoTask struct {
 | 
			
		||||
	Id      uint        `json:"id"`
 | 
			
		||||
	Channel string      `json:"channel"`
 | 
			
		||||
	UserId  int         `json:"user_id"`
 | 
			
		||||
	Type    string      `json:"type"`
 | 
			
		||||
	TaskId  string      `json:"task_id"`
 | 
			
		||||
	Prompt  string      `json:"prompt"` // 提示词
 | 
			
		||||
	Params  VideoParams `json:"params"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type VideoParams struct {
 | 
			
		||||
	PromptOptimize bool   `json:"prompt_optimize"` // 是否优化提示词
 | 
			
		||||
	Loop           bool   `json:"loop"`            // 是否循环参考图
 | 
			
		||||
	StartImgURL    string `json:"start_img_url"`   // 第一帧参考图地址
 | 
			
		||||
	EndImgURL      string `json:"end_img_url"`     // 最后一帧参考图地址
 | 
			
		||||
	Model          string `json:"model"`           // 使用哪个模型生成视频
 | 
			
		||||
	Radio          string `json:"radio"`           // 视频尺寸
 | 
			
		||||
	Style          string `json:"style"`           // 风格
 | 
			
		||||
	Duration       int    `json:"duration"`        // 视频时长(秒)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,7 @@ type WsMessage struct {
 | 
			
		||||
	Type    WsMsgType   `json:"type"` // 消息类别,start, end, img
 | 
			
		||||
	Content interface{} `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WsMsgType string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -36,11 +37,9 @@ type BizCode int
 | 
			
		||||
const (
 | 
			
		||||
	Success       = BizCode(0)
 | 
			
		||||
	Failed        = BizCode(1)
 | 
			
		||||
	NotAuthorized = BizCode(400) // 未授权
 | 
			
		||||
	NotPermission = BizCode(403) // 没有权限
 | 
			
		||||
	NotAuthorized = BizCode(401) // 未授权
 | 
			
		||||
 | 
			
		||||
	OkMsg       = "Success"
 | 
			
		||||
	ErrorMsg    = "系统开小差了"
 | 
			
		||||
	InvalidArgs = "非法参数或参数解析失败"
 | 
			
		||||
	NoData      = "No Data"
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										23
									
								
								api/go.mod
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								api/go.mod
									
									
									
									
									
								
							@@ -8,7 +8,6 @@ require (
 | 
			
		||||
	github.com/BurntSushi/toml v1.1.0
 | 
			
		||||
	github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
 | 
			
		||||
	github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
 | 
			
		||||
	github.com/eatmoreapple/openwechat v1.2.1
 | 
			
		||||
	github.com/gin-gonic/gin v1.9.1
 | 
			
		||||
	github.com/go-redis/redis/v8 v8.11.5
 | 
			
		||||
	github.com/golang-jwt/jwt/v5 v5.0.0
 | 
			
		||||
@@ -19,7 +18,6 @@ require (
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
 | 
			
		||||
	github.com/qiniu/go-sdk/v7 v7.17.1
 | 
			
		||||
	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 | 
			
		||||
	github.com/smartwalle/alipay/v3 v3.2.15
 | 
			
		||||
	go.uber.org/zap v1.23.0
 | 
			
		||||
	gopkg.in/natefinch/lumberjack.v2 v2.2.1
 | 
			
		||||
	gorm.io/driver/mysql v1.4.7
 | 
			
		||||
@@ -28,19 +26,27 @@ require (
 | 
			
		||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/mojocn/base64Captcha v1.3.1
 | 
			
		||||
	github.com/go-pay/gopay v1.5.101
 | 
			
		||||
	github.com/google/go-tika v0.3.1
 | 
			
		||||
	github.com/microcosm-cc/bluemonday v1.0.26
 | 
			
		||||
	github.com/shirou/gopsutil v3.21.11+incompatible
 | 
			
		||||
	github.com/shopspring/decimal v1.3.1
 | 
			
		||||
	github.com/syndtr/goleveldb v1.0.0
 | 
			
		||||
	golang.org/x/image v0.0.0-20211028202545-6944b10bf410
 | 
			
		||||
	golang.org/x/image v0.15.0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/aymerick/douceur v0.2.0 // indirect
 | 
			
		||||
	github.com/go-ole/go-ole v1.2.6 // indirect
 | 
			
		||||
	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
 | 
			
		||||
	github.com/go-pay/crypto v0.0.1 // indirect
 | 
			
		||||
	github.com/go-pay/errgroup v0.0.2 // indirect
 | 
			
		||||
	github.com/go-pay/util v0.0.2 // indirect
 | 
			
		||||
	github.com/go-pay/xlog v0.0.2 // indirect
 | 
			
		||||
	github.com/go-pay/xtime v0.0.2 // indirect
 | 
			
		||||
	github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
 | 
			
		||||
	github.com/tklauser/go-sysconf v0.3.14 // indirect
 | 
			
		||||
	github.com/tklauser/numcpus v0.8.0 // indirect
 | 
			
		||||
	github.com/gorilla/css v1.0.0 // indirect
 | 
			
		||||
	github.com/tklauser/go-sysconf v0.3.13 // indirect
 | 
			
		||||
	github.com/tklauser/numcpus v0.7.0 // indirect
 | 
			
		||||
	github.com/yusufpapurcu/wmi v1.2.4 // indirect
 | 
			
		||||
	go.uber.org/mock v0.4.0 // indirect
 | 
			
		||||
)
 | 
			
		||||
@@ -79,9 +85,6 @@ require (
 | 
			
		||||
	github.com/refraction-networking/utls v1.3.2 // indirect
 | 
			
		||||
	github.com/rs/xid v1.5.0 // indirect
 | 
			
		||||
	github.com/sirupsen/logrus v1.9.3 // indirect
 | 
			
		||||
	github.com/smartwalle/ncrypto v1.0.2 // indirect
 | 
			
		||||
	github.com/smartwalle/ngx v1.0.6 // indirect
 | 
			
		||||
	github.com/smartwalle/nsign v1.0.8 // indirect
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	go.uber.org/dig v1.16.1 // indirect
 | 
			
		||||
	golang.org/x/arch v0.3.0 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										66
									
								
								api/go.sum
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								api/go.sum
									
									
									
									
									
								
							@@ -6,6 +6,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw
 | 
			
		||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
 | 
			
		||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
 | 
			
		||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
 | 
			
		||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
 | 
			
		||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
 | 
			
		||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
 | 
			
		||||
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
 | 
			
		||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
			
		||||
@@ -26,8 +28,6 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
 | 
			
		||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 | 
			
		||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 | 
			
		||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
 | 
			
		||||
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
 | 
			
		||||
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
 | 
			
		||||
@@ -45,6 +45,18 @@ github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
 | 
			
		||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
 | 
			
		||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
 | 
			
		||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
 | 
			
		||||
github.com/go-pay/crypto v0.0.1 h1:B6InT8CLfSLc6nGRVx9VMJRBBazFMjr293+jl0lLXUY=
 | 
			
		||||
github.com/go-pay/crypto v0.0.1/go.mod h1:41oEIvHMKbNcYlWUlRWtsnC6+ASgh7u29z0gJXe5bes=
 | 
			
		||||
github.com/go-pay/errgroup v0.0.2 h1:5mZMdm0TDClDm2S3G0/sm0f8AuQRtz0dOrTHDR9R8Cc=
 | 
			
		||||
github.com/go-pay/errgroup v0.0.2/go.mod h1:0+4b8mvFMS71MIzsaC+gVvB4x37I93lRb2dqrwuU8x8=
 | 
			
		||||
github.com/go-pay/gopay v1.5.101 h1:rVb+sfv6hiQtknAlZnTTLvU27NvFJ4p0yglN/vPpGXI=
 | 
			
		||||
github.com/go-pay/gopay v1.5.101/go.mod h1:AW4Yj8jDZX9BM1/GTLTY1Gy5SHjiq8kQvG5sBTN2sxI=
 | 
			
		||||
github.com/go-pay/util v0.0.2 h1:goJ4f6kNY5zzdtg1Cj8oWC+Cw7bfg/qq2rJangMAb9U=
 | 
			
		||||
github.com/go-pay/util v0.0.2/go.mod h1:qM8VbyF1n7YAPZBSJONSPMPsPedhUTktewUAdf1AjPg=
 | 
			
		||||
github.com/go-pay/xlog v0.0.2 h1:kUg5X8/5VZAPDg1J5eGjA3MG0/H5kK6Ew0dW/Bycsws=
 | 
			
		||||
github.com/go-pay/xlog v0.0.2/go.mod h1:DbjMADPK4+Sjxj28ekK9goqn4zmyY4hql/zRiab+S9E=
 | 
			
		||||
github.com/go-pay/xtime v0.0.2 h1:7YR4/iuELsEHpJ6LUO0SVK80hQxDO9MLCfuVYIiTCRM=
 | 
			
		||||
github.com/go-pay/xtime v0.0.2/go.mod h1:W1yRbJaSt4CSBcdAtLBQ8xajiN/Pl5hquGczUcUE9xE=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 | 
			
		||||
@@ -70,8 +82,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
 | 
			
		||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
 | 
			
		||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
 | 
			
		||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
 | 
			
		||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
 | 
			
		||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
 | 
			
		||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 | 
			
		||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
 | 
			
		||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
 | 
			
		||||
@@ -79,11 +89,15 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pO
 | 
			
		||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 | 
			
		||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 | 
			
		||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 | 
			
		||||
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
 | 
			
		||||
github.com/google/go-tika v0.3.1/go.mod h1:DJh5N8qxXIl85QkqmXknd+PeeRkUOTbvwyYf7ieDz6c=
 | 
			
		||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 | 
			
		||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
 | 
			
		||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
 | 
			
		||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 | 
			
		||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
 | 
			
		||||
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
 | 
			
		||||
@@ -125,6 +139,8 @@ github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259
 | 
			
		||||
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 | 
			
		||||
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
 | 
			
		||||
github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs=
 | 
			
		||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
 | 
			
		||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
 | 
			
		||||
github.com/minio/minio-go/v7 v7.0.62 h1:qNYsFZHEzl+NfH8UxW4jpmlKav1qUAgfY30YNRneVhc=
 | 
			
		||||
@@ -137,8 +153,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
 | 
			
		||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 | 
			
		||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
 | 
			
		||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
 | 
			
		||||
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
 | 
			
		||||
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
 | 
			
		||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
 | 
			
		||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
 | 
			
		||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
 | 
			
		||||
@@ -186,14 +200,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
 | 
			
		||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 | 
			
		||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 | 
			
		||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
 | 
			
		||||
github.com/smartwalle/alipay/v3 v3.2.15 h1:3fvFJnINKKAOXHR/Iv20k1Z7KJ+nOh3oK214lELPqG8=
 | 
			
		||||
github.com/smartwalle/alipay/v3 v3.2.15/go.mod h1:niTNB609KyUYuAx9Bex/MawEjv2yPx4XOjxSAkqmGjE=
 | 
			
		||||
github.com/smartwalle/ncrypto v1.0.2 h1:pTAhCqtPCMhpOwFXX+EcMdR6PNzruBNoGQrN2S1GbGI=
 | 
			
		||||
github.com/smartwalle/ncrypto v1.0.2/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk=
 | 
			
		||||
github.com/smartwalle/ngx v1.0.6 h1:JPNqNOIj+2nxxFtrSkJO+vKJfeNUSEQueck/Wworjps=
 | 
			
		||||
github.com/smartwalle/ngx v1.0.6/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0=
 | 
			
		||||
github.com/smartwalle/nsign v1.0.8 h1:78KWtwKPrdt4Xsn+tNEBVxaTLIJBX9YRX0ZSrMUeuHo=
 | 
			
		||||
github.com/smartwalle/nsign v1.0.8/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc=
 | 
			
		||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
			
		||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 | 
			
		||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 | 
			
		||||
@@ -208,10 +214,10 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt
 | 
			
		||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 | 
			
		||||
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
 | 
			
		||||
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
 | 
			
		||||
github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY=
 | 
			
		||||
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
 | 
			
		||||
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
 | 
			
		||||
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
 | 
			
		||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 | 
			
		||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 | 
			
		||||
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
 | 
			
		||||
@@ -247,14 +253,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
 | 
			
		||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
 | 
			
		||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
 | 
			
		||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
 | 
			
		||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
 | 
			
		||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
 | 
			
		||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
 | 
			
		||||
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 | 
			
		||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
 | 
			
		||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
 | 
			
		||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
 | 
			
		||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
 | 
			
		||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 | 
			
		||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 | 
			
		||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
 | 
			
		||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 | 
			
		||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 | 
			
		||||
@@ -262,11 +270,16 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
 | 
			
		||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 | 
			
		||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 | 
			
		||||
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
 | 
			
		||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
 | 
			
		||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 | 
			
		||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
 | 
			
		||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
 | 
			
		||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
 | 
			
		||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
 | 
			
		||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 | 
			
		||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
 | 
			
		||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 | 
			
		||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 | 
			
		||||
@@ -281,17 +294,27 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
 | 
			
		||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
 | 
			
		||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 | 
			
		||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 | 
			
		||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 | 
			
		||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
 | 
			
		||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
 | 
			
		||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
 | 
			
		||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
 | 
			
		||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 | 
			
		||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 | 
			
		||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 | 
			
		||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 | 
			
		||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
 | 
			
		||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 | 
			
		||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 | 
			
		||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
 | 
			
		||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
 | 
			
		||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
 | 
			
		||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
 | 
			
		||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
 | 
			
		||||
@@ -299,6 +322,7 @@ golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 | 
			
		||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 | 
			
		||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 | 
			
		||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
 | 
			
		||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
 | 
			
		||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
 | 
			
		||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
 
 | 
			
		||||
@@ -8,19 +8,19 @@ package admin
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"github.com/mojocn/base64Captcha"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -29,37 +29,47 @@ import (
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
// Manager 管理员
 | 
			
		||||
type Manager struct {
 | 
			
		||||
	Username  string `json:"username"`
 | 
			
		||||
	Password  string `json:"password"`
 | 
			
		||||
	Captcha   string `json:"captcha"`    // 验证码
 | 
			
		||||
	CaptchaId string `json:"captcha_id"` // 验证码id
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const SuperManagerID = 1
 | 
			
		||||
 | 
			
		||||
type ManagerHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	redis *redis.Client
 | 
			
		||||
	redis   *redis.Client
 | 
			
		||||
	captcha *service.CaptchaService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
 | 
			
		||||
	return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
 | 
			
		||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, captcha *service.CaptchaService) *ManagerHandler {
 | 
			
		||||
	return &ManagerHandler{
 | 
			
		||||
		BaseHandler: handler.BaseHandler{DB: db, App: app},
 | 
			
		||||
		redis:       client,
 | 
			
		||||
		captcha:     captcha,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Login 登录
 | 
			
		||||
func (h *ManagerHandler) Login(c *gin.Context) {
 | 
			
		||||
	var data Manager
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
		Key      string `json:"key,omitempty"`
 | 
			
		||||
		Dots     string `json:"dots,omitempty"`
 | 
			
		||||
		X        int    `json:"x,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// add captcha
 | 
			
		||||
	if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
 | 
			
		||||
		resp.ERROR(c, "验证码错误!")
 | 
			
		||||
		return
 | 
			
		||||
	if h.App.SysConfig.EnabledVerify {
 | 
			
		||||
		var check bool
 | 
			
		||||
		if data.X != 0 {
 | 
			
		||||
			check = h.captcha.SlideCheck(data)
 | 
			
		||||
		} else {
 | 
			
		||||
			check = h.captcha.Check(data)
 | 
			
		||||
		}
 | 
			
		||||
		if !check {
 | 
			
		||||
			resp.ERROR(c, "请先完人机验证")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var manager model.AdminUser
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ package admin
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
@@ -31,7 +32,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
 | 
			
		||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id       uint   `json:"id"`
 | 
			
		||||
		Platform string `json:"platform"`
 | 
			
		||||
		Name     string `json:"name"`
 | 
			
		||||
		Type     string `json:"type"`
 | 
			
		||||
		Value    string `json:"value"`
 | 
			
		||||
@@ -48,24 +48,22 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		h.DB.Find(&apiKey, data.Id)
 | 
			
		||||
	}
 | 
			
		||||
	apiKey.Platform = data.Platform
 | 
			
		||||
	apiKey.Value = data.Value
 | 
			
		||||
	apiKey.Type = data.Type
 | 
			
		||||
	apiKey.ApiURL = data.ApiURL
 | 
			
		||||
	apiKey.Enabled = data.Enabled
 | 
			
		||||
	apiKey.ProxyURL = data.ProxyURL
 | 
			
		||||
	apiKey.Name = data.Name
 | 
			
		||||
	res := h.DB.Save(&apiKey)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Save(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var keyVo vo.ApiKey
 | 
			
		||||
	err := utils.CopyObject(apiKey, &keyVo)
 | 
			
		||||
	err = utils.CopyObject(apiKey, &keyVo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "数据拷贝失败!")
 | 
			
		||||
		resp.ERROR(c, fmt.Sprintf("拷贝数据失败:%v", err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	keyVo.Id = apiKey.Id
 | 
			
		||||
@@ -121,10 +119,9 @@ func (h *ApiKeyHandler) Set(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
@@ -137,10 +134,9 @@ func (h *ApiKeyHandler) Remove(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Where("id", id).Delete(&model.ApiKey{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Where("id", id).Delete(&model.ApiKey{}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,46 +0,0 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/mojocn/base64Captcha"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CaptchaHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
 | 
			
		||||
	return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CaptchaVo struct {
 | 
			
		||||
	CaptchaId string `json:"captcha_id"`
 | 
			
		||||
	PicPath   string `json:"pic_path"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetCaptcha 获取验证码
 | 
			
		||||
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
 | 
			
		||||
	var captchaVo CaptchaVo
 | 
			
		||||
	driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
 | 
			
		||||
	cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
 | 
			
		||||
	// b64s是图片的base64编码
 | 
			
		||||
	id, b64s, err := cp.Generate()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "生成验证码错误!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	captchaVo.CaptchaId = id
 | 
			
		||||
	captchaVo.PicPath = b64s
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, captchaVo)
 | 
			
		||||
}
 | 
			
		||||
@@ -259,10 +259,9 @@ func (h *ChatHandler) RemoveChat(c *gin.Context) {
 | 
			
		||||
// RemoveMessage 删除聊天记录
 | 
			
		||||
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", tx.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 
 | 
			
		||||
@@ -49,28 +49,32 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	item := model.ChatModel{
 | 
			
		||||
		Platform:    data.Platform,
 | 
			
		||||
		Name:        data.Name,
 | 
			
		||||
		Value:       data.Value,
 | 
			
		||||
		Enabled:     data.Enabled,
 | 
			
		||||
		SortNum:     data.SortNum,
 | 
			
		||||
		Open:        data.Open,
 | 
			
		||||
		MaxTokens:   data.MaxTokens,
 | 
			
		||||
		MaxContext:  data.MaxContext,
 | 
			
		||||
		Temperature: data.Temperature,
 | 
			
		||||
		KeyId:       data.KeyId,
 | 
			
		||||
		Power:       data.Power}
 | 
			
		||||
	item := model.ChatModel{}
 | 
			
		||||
	// 更新
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		h.DB.Where("id", data.Id).First(&item)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	item.Name = data.Name
 | 
			
		||||
	item.Value = data.Value
 | 
			
		||||
	item.Enabled = data.Enabled
 | 
			
		||||
	item.SortNum = data.SortNum
 | 
			
		||||
	item.Open = data.Open
 | 
			
		||||
	item.Power = data.Power
 | 
			
		||||
	item.MaxTokens = data.MaxTokens
 | 
			
		||||
	item.MaxContext = data.MaxContext
 | 
			
		||||
	item.Temperature = data.Temperature
 | 
			
		||||
	item.KeyId = data.KeyId
 | 
			
		||||
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		item.Id = data.Id
 | 
			
		||||
		res = h.DB.Select("*").Omit("created_at").Updates(&item)
 | 
			
		||||
		res = h.DB.Save(&item)
 | 
			
		||||
	} else {
 | 
			
		||||
		res = h.DB.Create(&item)
 | 
			
		||||
	}
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -89,12 +93,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
 | 
			
		||||
func (h *ChatModelHandler) List(c *gin.Context) {
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	enable := h.GetBool(c, "enable")
 | 
			
		||||
	platform := h.GetTrim(c, "platform")
 | 
			
		||||
	name := h.GetTrim(c, "name")
 | 
			
		||||
	if enable {
 | 
			
		||||
		session = session.Where("enabled", enable)
 | 
			
		||||
	}
 | 
			
		||||
	if platform != "" {
 | 
			
		||||
		session = session.Where("platform", platform)
 | 
			
		||||
	if name != "" {
 | 
			
		||||
		session = session.Where("name LIKE ?", name+"%")
 | 
			
		||||
	}
 | 
			
		||||
	var items []model.ChatModel
 | 
			
		||||
	var cms = make([]vo.ChatModel, 0)
 | 
			
		||||
@@ -143,10 +147,9 @@ func (h *ChatModelHandler) Set(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
@@ -164,10 +167,9 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -182,10 +184,9 @@ func (h *ChatModelHandler) Remove(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ package admin
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
@@ -45,11 +46,16 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
 | 
			
		||||
	role.Id = data.Id
 | 
			
		||||
	if data.CreatedAt > 0 {
 | 
			
		||||
		role.CreatedAt = time.Unix(data.CreatedAt, 0)
 | 
			
		||||
	} else {
 | 
			
		||||
		err = h.DB.Where("marker", data.Key).First(&role).Error
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	res := h.DB.Save(&role)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err = h.DB.Save(&role).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 填充 ID 数据
 | 
			
		||||
@@ -114,10 +120,9 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -137,10 +142,9 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 
 | 
			
		||||
@@ -12,8 +12,6 @@ import (
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/mj"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
@@ -28,16 +26,12 @@ type ConfigHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	levelDB        *store.LevelDB
 | 
			
		||||
	licenseService *service.LicenseService
 | 
			
		||||
	mjServicePool  *mj.ServicePool
 | 
			
		||||
	sdServicePool  *sd.ServicePool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler {
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService) *ConfigHandler {
 | 
			
		||||
	return &ConfigHandler{
 | 
			
		||||
		BaseHandler:    handler.BaseHandler{App: app, DB: db},
 | 
			
		||||
		levelDB:        levelDB,
 | 
			
		||||
		mjServicePool:  mjPool,
 | 
			
		||||
		sdServicePool:  sdPool,
 | 
			
		||||
		licenseService: licenseService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -50,6 +44,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
 | 
			
		||||
			Content string `json:"content,omitempty"`
 | 
			
		||||
			Updated bool   `json:"updated,omitempty"`
 | 
			
		||||
		} `json:"config"`
 | 
			
		||||
		ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
@@ -57,6 +52,12 @@ func (h *ConfigHandler) Update(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ONLY authorized user can change the copyright
 | 
			
		||||
	if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
 | 
			
		||||
		resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value := utils.JsonEncode(&data.Config)
 | 
			
		||||
	config := model.Config{Key: data.Key, Config: value}
 | 
			
		||||
	res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
 | 
			
		||||
@@ -139,59 +140,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
 | 
			
		||||
	license := h.licenseService.GetLicense()
 | 
			
		||||
	resp.SUCCESS(c, license)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAppConfig 获取内置配置
 | 
			
		||||
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, gin.H{
 | 
			
		||||
		"mj_plus":   h.App.Config.MjPlusConfigs,
 | 
			
		||||
		"mj_proxy":  h.App.Config.MjProxyConfigs,
 | 
			
		||||
		"sd":        h.App.Config.SdConfigs,
 | 
			
		||||
		"platforms": Platforms,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SaveDrawingConfig 保存AI绘画配置
 | 
			
		||||
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Sd      []types.StableDiffusionConfig `json:"sd"`
 | 
			
		||||
		MjPlus  []types.MjPlusConfig          `json:"mj_plus"`
 | 
			
		||||
		MjProxy []types.MjProxyConfig         `json:"mj_proxy"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	changed := false
 | 
			
		||||
	if configChanged(data.Sd, h.App.Config.SdConfigs) {
 | 
			
		||||
		logger.Debugf("SD 配置变动了")
 | 
			
		||||
		h.App.Config.SdConfigs = data.Sd
 | 
			
		||||
		h.sdServicePool.InitServices(data.Sd)
 | 
			
		||||
		changed = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) {
 | 
			
		||||
		logger.Debugf("MidJourney 配置变动了")
 | 
			
		||||
		h.App.Config.MjPlusConfigs = data.MjPlus
 | 
			
		||||
		h.App.Config.MjProxyConfigs = data.MjProxy
 | 
			
		||||
		h.mjServicePool.InitServices(data.MjPlus, data.MjProxy)
 | 
			
		||||
		changed = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if changed {
 | 
			
		||||
		err := core.SaveConfig(h.App.Config)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "更新配置文档失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func configChanged(c1 interface{}, c2 interface{}) bool {
 | 
			
		||||
	encode1 := utils.JsonEncode(c1)
 | 
			
		||||
	encode2 := utils.JsonEncode(c2)
 | 
			
		||||
	return utils.Md5(encode1) != utils.Md5(encode2)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -60,13 +60,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
 | 
			
		||||
		stats.Tokens += item.Tokens
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 众筹收入
 | 
			
		||||
	var rewards []model.Reward
 | 
			
		||||
	res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
 | 
			
		||||
	for _, item := range rewards {
 | 
			
		||||
		stats.Income += item.Amount
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 订单收入
 | 
			
		||||
	var orders []model.Order
 | 
			
		||||
	res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
 | 
			
		||||
@@ -101,13 +94,6 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
 | 
			
		||||
		historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 浮点数相加?
 | 
			
		||||
	// 统计最近7天的众筹
 | 
			
		||||
	res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
 | 
			
		||||
	for _, item := range rewards {
 | 
			
		||||
		incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计最近7天的订单
 | 
			
		||||
	res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
 | 
			
		||||
	for _, item := range orders {
 | 
			
		||||
 
 | 
			
		||||
@@ -69,10 +69,9 @@ func (h *FunctionHandler) Set(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
@@ -102,10 +101,9 @@ func (h *FunctionHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.DB.Delete(&model.Function{Id: uint(id)})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Delete(&model.Function{Id: uint(id)}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -41,17 +41,16 @@ func (h *MenuHandler) Save(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Save(&model.Menu{
 | 
			
		||||
	err := h.DB.Save(&model.Menu{
 | 
			
		||||
		Id:      data.Id,
 | 
			
		||||
		Name:    data.Name,
 | 
			
		||||
		Icon:    data.Icon,
 | 
			
		||||
		URL:     data.URL,
 | 
			
		||||
		SortNum: data.SortNum,
 | 
			
		||||
		Enabled: data.Enabled,
 | 
			
		||||
	})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
@@ -85,10 +84,9 @@ func (h *MenuHandler) Enable(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
@@ -106,10 +104,9 @@ func (h *MenuHandler) Sort(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -121,10 +118,9 @@ func (h *MenuHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.DB.Where("id", id).Delete(&model.Menu{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Where("id", id).Delete(&model.Menu{}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -92,12 +92,21 @@ func (h *OrderHandler) Remove(c *gin.Context) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *OrderHandler) Clear(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	err := h.DB.Unscoped().Where("status <> ?", 2).Where("pay_time", 0).Delete(&model.Order{}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -55,17 +55,16 @@ func (h *ProductHandler) Save(c *gin.Context) {
 | 
			
		||||
	if item.Id > 0 {
 | 
			
		||||
		item.CreatedAt = time.Unix(data.CreatedAt, 0)
 | 
			
		||||
	}
 | 
			
		||||
	res := h.DB.Save(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Save(&item).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var itemVo vo.Product
 | 
			
		||||
	err := utils.CopyObject(item, &itemVo)
 | 
			
		||||
	err = utils.CopyObject(item, &itemVo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "数据拷贝失败!")
 | 
			
		||||
		resp.ERROR(c, "数据拷贝失败: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	itemVo.Id = item.Id
 | 
			
		||||
@@ -106,10 +105,9 @@ func (h *ProductHandler) Enable(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
@@ -127,10 +125,9 @@ func (h *ProductHandler) Sort(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -142,10 +139,9 @@ func (h *ProductHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.DB.Where("id", id).Delete(&model.Product{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		err := h.DB.Where("id", id).Delete(&model.Product{}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										164
									
								
								api/handler/admin/redeem_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								api/handler/admin/redeem_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,164 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RedeemHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
 | 
			
		||||
	return &RedeemHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RedeemHandler) List(c *gin.Context) {
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	code := c.Query("code")
 | 
			
		||||
	status := h.GetInt(c, "status", -1)
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if code != "" {
 | 
			
		||||
		session.Where("code LIKE ?", "%"+code+"%")
 | 
			
		||||
	}
 | 
			
		||||
	if status == 0 {
 | 
			
		||||
		session.Where("redeem_at = ?", 0)
 | 
			
		||||
	} else if status == 1 {
 | 
			
		||||
		session.Where("redeem_at > ?", 0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.Redeem{}).Count(&total)
 | 
			
		||||
	var redeems []model.Redeem
 | 
			
		||||
	offset := (page - 1) * pageSize
 | 
			
		||||
	err := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&redeems).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var items = make([]vo.Redeem, 0)
 | 
			
		||||
	userIds := make([]uint, 0)
 | 
			
		||||
	for _, v := range redeems {
 | 
			
		||||
		userIds = append(userIds, v.UserId)
 | 
			
		||||
	}
 | 
			
		||||
	var users []model.User
 | 
			
		||||
	h.DB.Where("id IN ?", userIds).Find(&users)
 | 
			
		||||
	var userMap = make(map[uint]model.User)
 | 
			
		||||
	for _, u := range users {
 | 
			
		||||
		userMap[u.Id] = u
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, v := range redeems {
 | 
			
		||||
		var r vo.Redeem
 | 
			
		||||
		err = utils.CopyObject(v, &r)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		r.Id = v.Id
 | 
			
		||||
		r.Username = userMap[v.UserId].Username
 | 
			
		||||
		r.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
		items = append(items, r)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RedeemHandler) Create(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Name  string `json:"name"`
 | 
			
		||||
		Power int    `json:"power"`
 | 
			
		||||
		Num   int    `json:"num"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	counter := 0
 | 
			
		||||
	codes := make([]string, 0)
 | 
			
		||||
	var errMsg = ""
 | 
			
		||||
	if data.Num > 0 {
 | 
			
		||||
		for i := 0; i < data.Num; i++ {
 | 
			
		||||
			code, err := utils.GenRedeemCode(32)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				errMsg = err.Error()
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = h.DB.Create(&model.Redeem{
 | 
			
		||||
				Code:    code,
 | 
			
		||||
				Name:    data.Name,
 | 
			
		||||
				Power:   data.Power,
 | 
			
		||||
				Enabled: true,
 | 
			
		||||
			}).Error
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				errMsg = err.Error()
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			codes = append(codes, code)
 | 
			
		||||
			counter++
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if counter == 0 {
 | 
			
		||||
		resp.ERROR(c, errMsg)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, gin.H{
 | 
			
		||||
		"counter": counter,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RedeemHandler) Set(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id    uint        `json:"id"`
 | 
			
		||||
		Filed string      `json:"filed"`
 | 
			
		||||
		Value interface{} `json:"value"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := h.DB.Model(&model.Redeem{}).Where("id = ?", data.Id).Update(data.Filed, data.Value).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RedeemHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id uint
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		err := h.DB.Where("id", data.Id).Delete(&model.Redeem{}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,81 +0,0 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RewardHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
 | 
			
		||||
	return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RewardHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.Reward
 | 
			
		||||
	res := h.DB.Order("id DESC").Find(&items)
 | 
			
		||||
	var rewards = make([]vo.Reward, 0)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		userIds := make([]uint, 0)
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			userIds = append(userIds, v.UserId)
 | 
			
		||||
		}
 | 
			
		||||
		var users []model.User
 | 
			
		||||
		h.DB.Where("id IN ?", userIds).Find(&users)
 | 
			
		||||
		var userMap = make(map[uint]model.User)
 | 
			
		||||
		for _, u := range users {
 | 
			
		||||
			userMap[u.Id] = u
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			var r vo.Reward
 | 
			
		||||
			err := utils.CopyObject(v, &r)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			r.Id = v.Id
 | 
			
		||||
			r.Username = userMap[v.UserId].Username
 | 
			
		||||
			r.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
			r.UpdatedAt = v.UpdatedAt.Unix()
 | 
			
		||||
			rewards = append(rewards, r)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, rewards)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RewardHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id uint
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,12 +0,0 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
import "geekai/core/types"
 | 
			
		||||
 | 
			
		||||
var Platforms = []types.Platform{
 | 
			
		||||
	types.OpenAI,
 | 
			
		||||
	types.QWen,
 | 
			
		||||
	types.XunFei,
 | 
			
		||||
	types.ChatGLM,
 | 
			
		||||
	types.Baidu,
 | 
			
		||||
	types.Azure,
 | 
			
		||||
}
 | 
			
		||||
@@ -49,7 +49,7 @@ func (h *UserHandler) List(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session.Model(&model.User{}).Count(&total)
 | 
			
		||||
	res := session.Offset(offset).Limit(pageSize).Find(&items)
 | 
			
		||||
	res := session.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var user vo.User
 | 
			
		||||
@@ -112,7 +112,7 @@ func (h *UserHandler) Save(c *gin.Context) {
 | 
			
		||||
		res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update database:", res.Error)
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			resp.ERROR(c, res.Error.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 记录算力日志
 | 
			
		||||
@@ -136,10 +136,16 @@ func (h *UserHandler) Save(c *gin.Context) {
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		// 检查用户是否已经存在
 | 
			
		||||
		h.DB.Where("username", data.Username).First(&user)
 | 
			
		||||
		if user.Id > 0 {
 | 
			
		||||
			resp.ERROR(c, "用户名已存在")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		salt := utils.RandString(8)
 | 
			
		||||
		u := model.User{
 | 
			
		||||
			Username:    data.Username,
 | 
			
		||||
			Nickname:    fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
 | 
			
		||||
			Password:    utils.GenPassword(data.Password, salt),
 | 
			
		||||
			Avatar:      "/images/avatar/user.png",
 | 
			
		||||
			Salt:        salt,
 | 
			
		||||
@@ -149,6 +155,11 @@ func (h *UserHandler) Save(c *gin.Context) {
 | 
			
		||||
			ChatModels:  utils.JsonEncode(data.ChatModels),
 | 
			
		||||
			ExpiredTime: utils.Str2stamp(data.ExpiredTime),
 | 
			
		||||
		}
 | 
			
		||||
		if h.licenseService.GetLicense().Configs.DeCopy {
 | 
			
		||||
			u.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
 | 
			
		||||
		} else {
 | 
			
		||||
			u.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
 | 
			
		||||
		}
 | 
			
		||||
		res = h.DB.Create(&u)
 | 
			
		||||
		_ = utils.CopyObject(u, &userVo)
 | 
			
		||||
		userVo.Id = u.Id
 | 
			
		||||
@@ -157,8 +168,7 @@ func (h *UserHandler) Save(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -194,33 +204,69 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *UserHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	if id <= 0 {
 | 
			
		||||
	id := c.Query("id")
 | 
			
		||||
	ids := c.QueryArray("ids[]")
 | 
			
		||||
	if id != "" {
 | 
			
		||||
		ids = append(ids, id)
 | 
			
		||||
	}
 | 
			
		||||
	if len(ids) == 0 {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 删除用户
 | 
			
		||||
	res := h.DB.Where("id = ?", id).Delete(&model.User{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	var err error
 | 
			
		||||
	for _, id = range ids {
 | 
			
		||||
		// 删除用户
 | 
			
		||||
		if err = tx.Where("id", id).Delete(&model.User{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		// 删除聊天记录
 | 
			
		||||
		if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatItem{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		// 删除聊天历史记录
 | 
			
		||||
		if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatMessage{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		// 删除登录日志
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.UserLoginLog{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		// 删除算力日志
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.PowerLog{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.InviteLog{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		// 删除众筹日志
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.Redeem{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		// 删除绘图任务
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.SdJob{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.DallJob{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.SunoJob{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err = tx.Where("user_id = ?", id).Delete(&model.VideoJob{}).Error; err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "删除失败")
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 删除聊天记录
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
 | 
			
		||||
	// 删除聊天历史记录
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
 | 
			
		||||
	// 删除登录日志
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
 | 
			
		||||
	// 删除算力日志
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
 | 
			
		||||
	// 删除众筹日志
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
 | 
			
		||||
	// 删除绘图任务
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
 | 
			
		||||
	//  删除订单
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.Order{})
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -31,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.ChatModel
 | 
			
		||||
	var chatModels = make([]vo.ChatModel, 0)
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
 | 
			
		||||
	t := c.Query("type")
 | 
			
		||||
	if t != "" {
 | 
			
		||||
		session = session.Where("type", t)
 | 
			
		||||
	}
 | 
			
		||||
	// 如果用户没有登录,则加载所有开放模型
 | 
			
		||||
	if !h.IsLogin(c) {
 | 
			
		||||
		res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
 | 
			
		||||
		res = session.Where("open", true).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	} else {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		var models []int
 | 
			
		||||
 
 | 
			
		||||
@@ -29,45 +29,32 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
 | 
			
		||||
 | 
			
		||||
// List 获取用户聊天应用列表
 | 
			
		||||
func (h *ChatRoleHandler) List(c *gin.Context) {
 | 
			
		||||
	all := h.GetBool(c, "all")
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var roles []model.ChatRole
 | 
			
		||||
	var roleVos = make([]vo.ChatRole, 0)
 | 
			
		||||
	query := h.DB.Where("enable", true)
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		var user model.User
 | 
			
		||||
		h.DB.First(&user, userId)
 | 
			
		||||
		var roleKeys []string
 | 
			
		||||
		err := utils.JsonDecode(user.ChatRoles, &roleKeys)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "角色解析失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		query = query.Where("marker IN ?", roleKeys)
 | 
			
		||||
	}
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		query = query.Or("id", id)
 | 
			
		||||
	}
 | 
			
		||||
	res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.SUCCESS(c, roleVos)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取所有角色
 | 
			
		||||
	if userId == 0 || all {
 | 
			
		||||
		// 转成 vo
 | 
			
		||||
		var roleVos = make([]vo.ChatRole, 0)
 | 
			
		||||
		for _, r := range roles {
 | 
			
		||||
			var v vo.ChatRole
 | 
			
		||||
			err := utils.CopyObject(r, &v)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				v.Id = r.Id
 | 
			
		||||
				roleVos = append(roleVos, v)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		resp.SUCCESS(c, roleVos)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
	h.DB.First(&user, userId)
 | 
			
		||||
	var roleKeys []string
 | 
			
		||||
	err := utils.JsonDecode(user.ChatRoles, &roleKeys)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "角色解析失败!")
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var roleVos = make([]vo.ChatRole, 0)
 | 
			
		||||
	for _, r := range roles {
 | 
			
		||||
		if !utils.ContainsStr(roleKeys, r.Key) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		var v vo.ChatRole
 | 
			
		||||
		err := utils.CopyObject(r, &v)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
@@ -94,10 +81,9 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
	err = h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,111 +0,0 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 微软 Azure 模型消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			return fmt.Errorf("用户取消了请求:%s", prompt)
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if !strings.Contains(line, "data:") || len(line) < 30 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var responseBody = types.ApiResponse{}
 | 
			
		||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
			if err != nil { // 数据解析出错
 | 
			
		||||
				return errors.New(line)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(responseBody.Choices) == 0 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 初始化 role
 | 
			
		||||
			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
 | 
			
		||||
				message.Role = responseBody.Choices[0].Delta.Role
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
				continue
 | 
			
		||||
			} else if responseBody.Choices[0].FinishReason != "" {
 | 
			
		||||
				break // 输出完成或者输出中断了
 | 
			
		||||
			} else {
 | 
			
		||||
				content := responseBody.Choices[0].Delta.Content
 | 
			
		||||
				contents = append(contents, utils.InterfaceToString(content))
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		if err := scanner.Err(); err != nil {
 | 
			
		||||
			if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
				logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error("信息读取出错:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	} else {
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,185 +0,0 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type baiduResp struct {
 | 
			
		||||
	Id               string `json:"id"`
 | 
			
		||||
	Object           string `json:"object"`
 | 
			
		||||
	Created          int    `json:"created"`
 | 
			
		||||
	SentenceId       int    `json:"sentence_id"`
 | 
			
		||||
	IsEnd            bool   `json:"is_end"`
 | 
			
		||||
	IsTruncated      bool   `json:"is_truncated"`
 | 
			
		||||
	Result           string `json:"result"`
 | 
			
		||||
	NeedClearHistory bool   `json:"need_clear_history"`
 | 
			
		||||
	Usage            struct {
 | 
			
		||||
		PromptTokens     int `json:"prompt_tokens"`
 | 
			
		||||
		CompletionTokens int `json:"completion_tokens"`
 | 
			
		||||
		TotalTokens      int `json:"total_tokens"`
 | 
			
		||||
	} `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 百度文心一言消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			return fmt.Errorf("用户取消了请求:%s", prompt)
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var content string
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if len(line) < 5 || strings.HasPrefix(line, "id:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if strings.HasPrefix(line, "data:") {
 | 
			
		||||
				content = line[5:]
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 处理代码换行
 | 
			
		||||
			if len(content) == 0 {
 | 
			
		||||
				content = "\n"
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var resp baiduResp
 | 
			
		||||
			err := utils.JsonDecode(content, &resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with parse data line: ", err)
 | 
			
		||||
				utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(contents) == 0 {
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
			}
 | 
			
		||||
			utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
				Type:    types.WsMiddle,
 | 
			
		||||
				Content: utils.InterfaceToString(resp.Result),
 | 
			
		||||
			})
 | 
			
		||||
			contents = append(contents, resp.Result)
 | 
			
		||||
 | 
			
		||||
			if resp.IsTruncated {
 | 
			
		||||
				utils.ReplyMessage(ws, "AI 输出异常中断")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if resp.IsEnd {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		if err := scanner.Err(); err != nil {
 | 
			
		||||
			if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
				logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error("信息读取出错:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tokenString, err := h.redis.Get(ctx, apiKey).Result()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return tokenString, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expr := time.Hour * 24 * 20 // access_token 有效期
 | 
			
		||||
	key := strings.Split(apiKey, "|")
 | 
			
		||||
	if len(key) != 2 {
 | 
			
		||||
		return "", fmt.Errorf("invalid api key: %s", apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	req, err := http.NewRequest("POST", url, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Add("Content-Type", "application/json")
 | 
			
		||||
	req.Header.Add("Accept", "application/json")
 | 
			
		||||
 | 
			
		||||
	res, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with send request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
 | 
			
		||||
	body, err := io.ReadAll(res.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with read response: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	var r map[string]interface{}
 | 
			
		||||
	err = json.Unmarshal(body, &r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse response: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r["error"] != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with api response: %s", r["error_description"])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tokenString = fmt.Sprintf("%s", r["access_token"])
 | 
			
		||||
	h.redis.Set(ctx, apiKey, tokenString, expr)
 | 
			
		||||
	return tokenString, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -44,14 +44,20 @@ type ChatHandler struct {
 | 
			
		||||
	redis          *redis.Client
 | 
			
		||||
	uploadManager  *oss.UploaderManager
 | 
			
		||||
	licenseService *service.LicenseService
 | 
			
		||||
	ReqCancelFunc  *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
 | 
			
		||||
	ChatContexts   *types.LMap[string, []types.Message]    // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
	userService    *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
 | 
			
		||||
	return &ChatHandler{
 | 
			
		||||
		BaseHandler:    handler.BaseHandler{App: app, DB: db},
 | 
			
		||||
		redis:          redis,
 | 
			
		||||
		uploadManager:  manager,
 | 
			
		||||
		licenseService: licenseService,
 | 
			
		||||
		ReqCancelFunc:  types.NewLMap[string, context.CancelFunc](),
 | 
			
		||||
		ChatContexts:   types.NewLMap[string, []types.Message](),
 | 
			
		||||
		userService:    userService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -67,6 +73,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
	roleId := h.GetInt(c, "role_id", 0)
 | 
			
		||||
	chatId := c.Query("chat_id")
 | 
			
		||||
	modelId := h.GetInt(c, "model_id", 0)
 | 
			
		||||
	tools := c.Query("tools")
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	var chatRole model.ChatRole
 | 
			
		||||
@@ -89,21 +96,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.App.ChatSession.Get(sessionId)
 | 
			
		||||
	if session == nil {
 | 
			
		||||
		user, err := h.GetLoginUser(c)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Info("用户未登录")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		session = &types.ChatSession{
 | 
			
		||||
			SessionId: sessionId,
 | 
			
		||||
			ClientIP:  c.ClientIP(),
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
		}
 | 
			
		||||
		h.App.ChatSession.Put(sessionId, session)
 | 
			
		||||
	session := &types.ChatSession{
 | 
			
		||||
		SessionId: sessionId,
 | 
			
		||||
		ClientIP:  c.ClientIP(),
 | 
			
		||||
		UserId:    h.GetLoginUserId(c),
 | 
			
		||||
		Tools:     tools,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// use old chat data override the chat model and role ID
 | 
			
		||||
@@ -123,24 +120,19 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
		MaxTokens:   chatModel.MaxTokens,
 | 
			
		||||
		MaxContext:  chatModel.MaxContext,
 | 
			
		||||
		Temperature: chatModel.Temperature,
 | 
			
		||||
		KeyId:       chatModel.KeyId,
 | 
			
		||||
		Platform:    chatModel.Platform}
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
 | 
			
		||||
		KeyId:       chatModel.KeyId}
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
			
		||||
 | 
			
		||||
	// 保存会话连接
 | 
			
		||||
	h.App.ChatClients.Put(sessionId, client)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := client.Receive()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
 | 
			
		||||
				client.Close()
 | 
			
		||||
				h.App.ChatClients.Delete(sessionId)
 | 
			
		||||
				h.App.ChatSession.Delete(sessionId)
 | 
			
		||||
				cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
 | 
			
		||||
				cancelFunc := h.ReqCancelFunc.Get(sessionId)
 | 
			
		||||
				if cancelFunc != nil {
 | 
			
		||||
					cancelFunc()
 | 
			
		||||
					h.App.ReqCancelFunc.Delete(sessionId)
 | 
			
		||||
					h.ReqCancelFunc.Delete(sessionId)
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
@@ -160,7 +152,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
			logger.Info("Receive a message: ", message.Content)
 | 
			
		||||
 | 
			
		||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
			h.App.ReqCancelFunc.Put(sessionId, cancel)
 | 
			
		||||
			h.ReqCancelFunc.Put(sessionId, cancel)
 | 
			
		||||
			// 回复消息
 | 
			
		||||
			err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -219,63 +211,48 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		Model:  session.Model.Value,
 | 
			
		||||
		Stream: true,
 | 
			
		||||
	}
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
 | 
			
		||||
		req.Temperature = session.Model.Temperature
 | 
			
		||||
		req.MaxTokens = session.Model.MaxTokens
 | 
			
		||||
		break
 | 
			
		||||
	case types.OpenAI.Value:
 | 
			
		||||
		req.Temperature = session.Model.Temperature
 | 
			
		||||
		req.MaxTokens = session.Model.MaxTokens
 | 
			
		||||
		// OpenAI 支持函数功能
 | 
			
		||||
	req.Temperature = session.Model.Temperature
 | 
			
		||||
	req.MaxTokens = session.Model.MaxTokens
 | 
			
		||||
 | 
			
		||||
	if session.Tools != "" {
 | 
			
		||||
		toolIds := strings.Split(session.Tools, ",")
 | 
			
		||||
		var items []model.Function
 | 
			
		||||
		res := h.DB.Where("enabled", true).Find(&items)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tools = make([]types.Tool, 0)
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			var parameters map[string]interface{}
 | 
			
		||||
			err = utils.JsonDecode(v.Parameters, ¶meters)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
		res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
 | 
			
		||||
		if res.Error == nil {
 | 
			
		||||
			var tools = make([]types.Tool, 0)
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				var parameters map[string]interface{}
 | 
			
		||||
				err = utils.JsonDecode(v.Parameters, ¶meters)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				tool := types.Tool{
 | 
			
		||||
					Type: "function",
 | 
			
		||||
					Function: types.Function{
 | 
			
		||||
						Name:        v.Name,
 | 
			
		||||
						Description: v.Description,
 | 
			
		||||
						Parameters:  parameters,
 | 
			
		||||
					},
 | 
			
		||||
				}
 | 
			
		||||
				if v, ok := parameters["required"]; v == nil || !ok {
 | 
			
		||||
					tool.Function.Parameters["required"] = []string{}
 | 
			
		||||
				}
 | 
			
		||||
				tools = append(tools, tool)
 | 
			
		||||
			}
 | 
			
		||||
			tool := types.Tool{
 | 
			
		||||
				Type: "function",
 | 
			
		||||
				Function: types.Function{
 | 
			
		||||
					Name:        v.Name,
 | 
			
		||||
					Description: v.Description,
 | 
			
		||||
					Parameters:  parameters,
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
			if v, ok := parameters["required"]; v == nil || !ok {
 | 
			
		||||
				tool.Function.Parameters["required"] = []string{}
 | 
			
		||||
			}
 | 
			
		||||
			tools = append(tools, tool)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(tools) > 0 {
 | 
			
		||||
			req.Tools = tools
 | 
			
		||||
			req.ToolChoice = "auto"
 | 
			
		||||
			if len(tools) > 0 {
 | 
			
		||||
				req.Tools = tools
 | 
			
		||||
				req.ToolChoice = "auto"
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		req.Parameters = map[string]interface{}{
 | 
			
		||||
			"max_tokens":  session.Model.MaxTokens,
 | 
			
		||||
			"temperature": session.Model.Temperature,
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 加载聊天上下文
 | 
			
		||||
	chatCtx := make([]types.Message, 0)
 | 
			
		||||
	messages := make([]types.Message, 0)
 | 
			
		||||
	if h.App.SysConfig.EnableContext {
 | 
			
		||||
		if h.App.ChatContexts.Has(session.ChatId) {
 | 
			
		||||
			messages = h.App.ChatContexts.Get(session.ChatId)
 | 
			
		||||
		if h.ChatContexts.Has(session.ChatId) {
 | 
			
		||||
			messages = h.ChatContexts.Get(session.ChatId)
 | 
			
		||||
		} else {
 | 
			
		||||
			_ = utils.JsonDecode(role.Context, &messages)
 | 
			
		||||
			if h.App.SysConfig.ContextDeep > 0 {
 | 
			
		||||
@@ -300,7 +277,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
 | 
			
		||||
		tokens += tks + promptTokens
 | 
			
		||||
 | 
			
		||||
		for _, v := range messages {
 | 
			
		||||
		for i := len(messages) - 1; i >= 0; i-- {
 | 
			
		||||
			v := messages[i]
 | 
			
		||||
			tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
			
		||||
			// 上下文 token 超出了模型的最大上下文长度
 | 
			
		||||
			if tokens+tks >= session.Model.MaxContext {
 | 
			
		||||
@@ -323,66 +301,69 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		reqMgs = append(reqMgs, m)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if session.Model.Platform == types.QWen.Value {
 | 
			
		||||
		req.Input = make(map[string]interface{})
 | 
			
		||||
		reqMgs = append(reqMgs, types.Message{
 | 
			
		||||
			Role:    "user",
 | 
			
		||||
			Content: prompt,
 | 
			
		||||
		})
 | 
			
		||||
		req.Input["messages"] = reqMgs
 | 
			
		||||
	} else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model
 | 
			
		||||
		imgURLs := utils.ExtractImgURL(prompt)
 | 
			
		||||
		logger.Debugf("detected IMG: %+v", imgURLs)
 | 
			
		||||
		var content interface{}
 | 
			
		||||
		if len(imgURLs) > 0 {
 | 
			
		||||
			data := make([]interface{}, 0)
 | 
			
		||||
			text := prompt
 | 
			
		||||
			for _, v := range imgURLs {
 | 
			
		||||
				text = strings.Replace(text, v, "", 1)
 | 
			
		||||
				data = append(data, gin.H{
 | 
			
		||||
					"type": "image_url",
 | 
			
		||||
					"image_url": gin.H{
 | 
			
		||||
						"url": v,
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
	fullPrompt := prompt
 | 
			
		||||
	text := prompt
 | 
			
		||||
	// extract files in prompt
 | 
			
		||||
	files := utils.ExtractFileURLs(prompt)
 | 
			
		||||
	logger.Debugf("detected FILES: %+v", files)
 | 
			
		||||
	// 如果不是逆向模型,则提取文件内容
 | 
			
		||||
	if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
 | 
			
		||||
		strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
 | 
			
		||||
		strings.HasSuffix(session.Model.Value, "claude-3")) {
 | 
			
		||||
		contents := make([]string, 0)
 | 
			
		||||
		var file model.File
 | 
			
		||||
		for _, v := range files {
 | 
			
		||||
			h.DB.Where("url = ?", v).First(&file)
 | 
			
		||||
			content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with read file: ", err)
 | 
			
		||||
			} else {
 | 
			
		||||
				contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
 | 
			
		||||
			}
 | 
			
		||||
			data = append(data, gin.H{
 | 
			
		||||
				"type": "text",
 | 
			
		||||
				"text": text,
 | 
			
		||||
			})
 | 
			
		||||
			content = data
 | 
			
		||||
		} else {
 | 
			
		||||
			content = prompt
 | 
			
		||||
			text = strings.Replace(text, v, "", 1)
 | 
			
		||||
		}
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
 | 
			
		||||
		if tokens > session.Model.MaxContext {
 | 
			
		||||
			return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
 | 
			
		||||
		}
 | 
			
		||||
		req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
			"role":    "user",
 | 
			
		||||
			"content": content,
 | 
			
		||||
		})
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
			"role":    "user",
 | 
			
		||||
			"content": prompt,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debug("最终Prompt:", fullPrompt)
 | 
			
		||||
 | 
			
		||||
	// extract images from prompt
 | 
			
		||||
	imgURLs := utils.ExtractImgURLs(prompt)
 | 
			
		||||
	logger.Debugf("detected IMG: %+v", imgURLs)
 | 
			
		||||
	var content interface{}
 | 
			
		||||
	if len(imgURLs) > 0 {
 | 
			
		||||
		data := make([]interface{}, 0)
 | 
			
		||||
		for _, v := range imgURLs {
 | 
			
		||||
			text = strings.Replace(text, v, "", 1)
 | 
			
		||||
			data = append(data, gin.H{
 | 
			
		||||
				"type": "image_url",
 | 
			
		||||
				"image_url": gin.H{
 | 
			
		||||
					"url": v,
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		data = append(data, gin.H{
 | 
			
		||||
			"type": "text",
 | 
			
		||||
			"text": strings.TrimSpace(text),
 | 
			
		||||
		})
 | 
			
		||||
		content = data
 | 
			
		||||
	} else {
 | 
			
		||||
		content = fullPrompt
 | 
			
		||||
	}
 | 
			
		||||
	req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
		"role":    "user",
 | 
			
		||||
		"content": content,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("%+v", req.Messages)
 | 
			
		||||
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure.Value:
 | 
			
		||||
		return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.OpenAI.Value:
 | 
			
		||||
		return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.ChatGLM.Value:
 | 
			
		||||
		return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.Baidu.Value:
 | 
			
		||||
		return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.XunFei.Value:
 | 
			
		||||
		return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
	return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tokens 统计 token 数量
 | 
			
		||||
@@ -442,9 +423,9 @@ func getTotalTokens(req types.ApiRequest) int {
 | 
			
		||||
// StopGenerate 停止生成
 | 
			
		||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
 | 
			
		||||
	sessionId := c.Query("session_id")
 | 
			
		||||
	if h.App.ReqCancelFunc.Has(sessionId) {
 | 
			
		||||
		h.App.ReqCancelFunc.Get(sessionId)()
 | 
			
		||||
		h.App.ReqCancelFunc.Delete(sessionId)
 | 
			
		||||
	if h.ReqCancelFunc.Has(sessionId) {
 | 
			
		||||
		h.ReqCancelFunc.Get(sessionId)()
 | 
			
		||||
		h.ReqCancelFunc.Delete(sessionId)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, types.OkMsg)
 | 
			
		||||
}
 | 
			
		||||
@@ -454,59 +435,24 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
 | 
			
		||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
 | 
			
		||||
	// if the chat model bind a KEY, use it directly
 | 
			
		||||
	if session.Model.KeyId > 0 {
 | 
			
		||||
		h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
 | 
			
		||||
		h.DB.Where("id", session.Model.KeyId).Find(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
 | 
			
		||||
		h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		return nil, errors.New("no available key, please import key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ONLY allow apiURL in blank list
 | 
			
		||||
	if session.Model.Platform == types.OpenAI.Value {
 | 
			
		||||
		err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure.Value:
 | 
			
		||||
		md := strings.Replace(req.Model, ".", "", 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM.Value:
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
 | 
			
		||||
		req.Messages = nil
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu.Value:
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		break
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		apiURL = apiKey.ApiURL
 | 
			
		||||
		req.Messages = nil
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
		apiURL = apiKey.ApiURL
 | 
			
		||||
	}
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
	// 百度文心,需要串接 access_token
 | 
			
		||||
	if session.Model.Platform == types.Baidu.Value {
 | 
			
		||||
		token, err := h.getBaiduToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		logger.Info("百度文心 Access_Token:", token)
 | 
			
		||||
		apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf(utils.JsonEncode(req))
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
 | 
			
		||||
	// 创建 HttpClient 请求对象
 | 
			
		||||
	var client *http.Client
 | 
			
		||||
	requestBody, err := json.Marshal(req)
 | 
			
		||||
@@ -530,28 +476,10 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
 | 
			
		||||
	} else {
 | 
			
		||||
		client = http.DefaultClient
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure.Value:
 | 
			
		||||
		request.Header.Set("api-key", apiKey.Value)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM.Value:
 | 
			
		||||
		token, err := h.getChatGLMToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu.Value:
 | 
			
		||||
		request.RequestURI = ""
 | 
			
		||||
	case types.OpenAI.Value:
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
		break
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
		request.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
 | 
			
		||||
	request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
	// 更新API KEY 最后使用时间
 | 
			
		||||
	h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
	return client.Do(request)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -561,24 +489,15 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
 | 
			
		||||
	if session.Model.Power > 0 {
 | 
			
		||||
		power = session.Model.Power
 | 
			
		||||
	}
 | 
			
		||||
	res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		// 记录算力消费日志
 | 
			
		||||
		var u model.User
 | 
			
		||||
		h.DB.Where("id", userVo.Id).First(&u)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    userVo.Id,
 | 
			
		||||
			Username:  userVo.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Balance:   u.Power,
 | 
			
		||||
			Model:     session.Model.Value,
 | 
			
		||||
			Remark:    fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerConsume,
 | 
			
		||||
		Model:  session.Model.Value,
 | 
			
		||||
		Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) saveChatHistory(
 | 
			
		||||
@@ -602,7 +521,7 @@ func (h *ChatHandler) saveChatHistory(
 | 
			
		||||
	if h.App.SysConfig.EnableContext {
 | 
			
		||||
		chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
		chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
		h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
		h.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 追加聊天记录
 | 
			
		||||
@@ -624,9 +543,9 @@ func (h *ChatHandler) saveChatHistory(
 | 
			
		||||
	}
 | 
			
		||||
	historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
	historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
	res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
	err = h.DB.Save(&historyUserMsg).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("failed to save prompt history message: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// for reply
 | 
			
		||||
@@ -646,30 +565,32 @@ func (h *ChatHandler) saveChatHistory(
 | 
			
		||||
	}
 | 
			
		||||
	historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
	historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
	res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
	err = h.DB.Create(&historyReplyMsg).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("failed to save reply history message: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户算力
 | 
			
		||||
	if session.Model.Power > 0 {
 | 
			
		||||
		// 更新用户算力
 | 
			
		||||
		h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
		// 保存当前会话
 | 
			
		||||
		var chatItem model.ChatItem
 | 
			
		||||
		res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			chatItem.ChatId = session.ChatId
 | 
			
		||||
			chatItem.UserId = session.UserId
 | 
			
		||||
			chatItem.RoleId = role.Id
 | 
			
		||||
			chatItem.ModelId = session.Model.Id
 | 
			
		||||
			if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
				chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
			} else {
 | 
			
		||||
				chatItem.Title = prompt
 | 
			
		||||
			}
 | 
			
		||||
			chatItem.Model = req.Model
 | 
			
		||||
			h.DB.Create(&chatItem)
 | 
			
		||||
	}
 | 
			
		||||
	// 保存当前会话
 | 
			
		||||
	var chatItem model.ChatItem
 | 
			
		||||
	err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		chatItem.ChatId = session.ChatId
 | 
			
		||||
		chatItem.UserId = userVo.Id
 | 
			
		||||
		chatItem.RoleId = role.Id
 | 
			
		||||
		chatItem.ModelId = session.Model.Id
 | 
			
		||||
		if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
			chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
		} else {
 | 
			
		||||
			chatItem.Title = prompt
 | 
			
		||||
		}
 | 
			
		||||
		chatItem.Model = req.Model
 | 
			
		||||
		err = h.DB.Create(&chatItem).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("failed to save chat item: ", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -689,7 +610,7 @@ func (h *ChatHandler) extractImgUrl(text string) string {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
 | 
			
		||||
		newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
			continue
 | 
			
		||||
 
 | 
			
		||||
@@ -96,7 +96,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
 | 
			
		||||
	for _, chat := range chats {
 | 
			
		||||
		chatIds = append(chatIds, chat.ChatId)
 | 
			
		||||
		// 清空会话上下文
 | 
			
		||||
		h.App.ChatContexts.Delete(chat.ChatId)
 | 
			
		||||
		h.ChatContexts.Delete(chat.ChatId)
 | 
			
		||||
	}
 | 
			
		||||
	err = h.DB.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
 | 
			
		||||
@@ -108,8 +108,6 @@ func (h *ChatHandler) Clear(c *gin.Context) {
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			return res.Error
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
@@ -175,7 +173,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
 | 
			
		||||
	// TODO: 是否要删除 MidJourney 绘画记录和图片文件?
 | 
			
		||||
 | 
			
		||||
	// 清空会话上下文
 | 
			
		||||
	h.App.ChatContexts.Delete(chatId)
 | 
			
		||||
	h.ChatContexts.Delete(chatId)
 | 
			
		||||
	resp.SUCCESS(c, types.OkMsg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,142 +0,0 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 清华大学 ChatGML 消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			return fmt.Errorf("用户取消了请求:%s", prompt)
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var event, content string
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if len(line) < 5 || strings.HasPrefix(line, "id:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(line, "event:") {
 | 
			
		||||
				event = line[6:]
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if strings.HasPrefix(line, "data:") {
 | 
			
		||||
				content = line[5:]
 | 
			
		||||
			}
 | 
			
		||||
			// 处理代码换行
 | 
			
		||||
			if len(content) == 0 {
 | 
			
		||||
				content = "\n"
 | 
			
		||||
			}
 | 
			
		||||
			switch event {
 | 
			
		||||
			case "add":
 | 
			
		||||
				if len(contents) == 0 {
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
				}
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: utils.InterfaceToString(content),
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, content)
 | 
			
		||||
			case "finish":
 | 
			
		||||
				break
 | 
			
		||||
			case "error":
 | 
			
		||||
				utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
 | 
			
		||||
				break
 | 
			
		||||
			case "interrupted":
 | 
			
		||||
				utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		if err := scanner.Err(); err != nil {
 | 
			
		||||
			if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
				logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error("信息读取出错:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	tokenString, err := h.redis.Get(ctx, apiKey).Result()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return tokenString, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expr := time.Hour * 2
 | 
			
		||||
	key := strings.Split(apiKey, ".")
 | 
			
		||||
	if len(key) != 2 {
 | 
			
		||||
		return "", fmt.Errorf("invalid api key: %s", apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"api_key":   key[0],
 | 
			
		||||
		"timestamp": time.Now().Unix(),
 | 
			
		||||
		"exp":       time.Now().Add(expr).Add(time.Second * 10).Unix(),
 | 
			
		||||
	})
 | 
			
		||||
	token.Header["alg"] = "HS256"
 | 
			
		||||
	token.Header["sign_type"] = "SIGN"
 | 
			
		||||
	delete(token.Header, "typ")
 | 
			
		||||
	// Sign and get the complete encoded token as a string using the secret
 | 
			
		||||
	tokenString, err = token.SignedString([]byte(key[1]))
 | 
			
		||||
	h.redis.Set(ctx, apiKey, tokenString, expr)
 | 
			
		||||
	return tokenString, err
 | 
			
		||||
}
 | 
			
		||||
@@ -65,7 +65,6 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			if !strings.Contains(line, "data:") || len(line) < 30 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var responseBody = types.ApiResponse{}
 | 
			
		||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
			if err != nil { // 数据解析出错
 | 
			
		||||
@@ -74,6 +73,9 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
 | 
			
		||||
				utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
 | 
			
		||||
@@ -142,7 +144,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if toolCall { // 调用函数完成任务
 | 
			
		||||
			var params map[string]interface{}
 | 
			
		||||
			params := make(map[string]interface{})
 | 
			
		||||
			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
 | 
			
		||||
			logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
 | 
			
		||||
			params["user_id"] = userVo.Id
 | 
			
		||||
 
 | 
			
		||||
@@ -1,150 +0,0 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/syndtr/goleveldb/leveldb/errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type qWenResp struct {
 | 
			
		||||
	Output struct {
 | 
			
		||||
		FinishReason string `json:"finish_reason"`
 | 
			
		||||
		Text         string `json:"text"`
 | 
			
		||||
	} `json:"output,omitempty"`
 | 
			
		||||
	Usage struct {
 | 
			
		||||
		TotalTokens  int `json:"total_tokens"`
 | 
			
		||||
		InputTokens  int `json:"input_tokens"`
 | 
			
		||||
		OutputTokens int `json:"output_tokens"`
 | 
			
		||||
	} `json:"usage,omitempty"`
 | 
			
		||||
	RequestID string `json:"request_id"`
 | 
			
		||||
 | 
			
		||||
	Code    string `json:"code,omitempty"`
 | 
			
		||||
	Message string `json:"message,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 通义千问消息发送实现
 | 
			
		||||
func (h *ChatHandler) sendQWenMessage(
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			return fmt.Errorf("用户取消了请求:%s", prompt)
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
	}
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
 | 
			
		||||
		var content, lastText, newText string
 | 
			
		||||
		var outPutStart = false
 | 
			
		||||
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if len(line) < 5 || strings.HasPrefix(line, "id:") ||
 | 
			
		||||
				strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !strings.HasPrefix(line, "data:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			content = line[5:]
 | 
			
		||||
			var resp qWenResp
 | 
			
		||||
			if len(contents) == 0 { // 发送消息头
 | 
			
		||||
				if !outPutStart {
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					outPutStart = true
 | 
			
		||||
					continue
 | 
			
		||||
				} else {
 | 
			
		||||
					// 处理代码换行
 | 
			
		||||
					content = "\n"
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				err := utils.JsonDecode(content, &resp)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error("error with parse data line: ", content)
 | 
			
		||||
					utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				if resp.Message != "" {
 | 
			
		||||
					utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			//通过比较 lastText(上一次的文本)和 currentText(当前的文本),
 | 
			
		||||
			//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
 | 
			
		||||
			//每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
 | 
			
		||||
			currentText := resp.Output.Text
 | 
			
		||||
			if currentText != lastText {
 | 
			
		||||
				// 提取新增文本
 | 
			
		||||
				newText = strings.Replace(currentText, lastText, "", 1)
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: utils.InterfaceToString(newText),
 | 
			
		||||
				})
 | 
			
		||||
				lastText = currentText // 更新 lastText
 | 
			
		||||
			}
 | 
			
		||||
			contents = append(contents, newText)
 | 
			
		||||
 | 
			
		||||
			if resp.Output.FinishReason == "stop" {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		} //end for
 | 
			
		||||
 | 
			
		||||
		if err := scanner.Err(); err != nil {
 | 
			
		||||
			if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
				logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error("信息读取出错:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,255 +0,0 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type xunFeiResp struct {
 | 
			
		||||
	Header struct {
 | 
			
		||||
		Code    int    `json:"code"`
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
		Sid     string `json:"sid"`
 | 
			
		||||
		Status  int    `json:"status"`
 | 
			
		||||
	} `json:"header"`
 | 
			
		||||
	Payload struct {
 | 
			
		||||
		Choices struct {
 | 
			
		||||
			Status int `json:"status"`
 | 
			
		||||
			Seq    int `json:"seq"`
 | 
			
		||||
			Text   []struct {
 | 
			
		||||
				Content string `json:"content"`
 | 
			
		||||
				Role    string `json:"role"`
 | 
			
		||||
				Index   int    `json:"index"`
 | 
			
		||||
			} `json:"text"`
 | 
			
		||||
		} `json:"choices"`
 | 
			
		||||
		Usage struct {
 | 
			
		||||
			Text struct {
 | 
			
		||||
				QuestionTokens   int `json:"question_tokens"`
 | 
			
		||||
				PromptTokens     int `json:"prompt_tokens"`
 | 
			
		||||
				CompletionTokens int `json:"completion_tokens"`
 | 
			
		||||
				TotalTokens      int `json:"total_tokens"`
 | 
			
		||||
			} `json:"text"`
 | 
			
		||||
		} `json:"usage"`
 | 
			
		||||
	} `json:"payload"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var Model2URL = map[string]string{
 | 
			
		||||
	"general":     "v1.1",
 | 
			
		||||
	"generalv2":   "v2.1",
 | 
			
		||||
	"generalv3":   "v3.1",
 | 
			
		||||
	"generalv3.5": "v3.5",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 科大讯飞消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	// use the bind key
 | 
			
		||||
	if session.Model.KeyId > 0 {
 | 
			
		||||
		res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
	}
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	key := strings.Split(apiKey.Value, "|")
 | 
			
		||||
	if len(key) != 3 {
 | 
			
		||||
		utils.ReplyMessage(ws, "非法的 API KEY!")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
 | 
			
		||||
	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
 | 
			
		||||
	wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
 | 
			
		||||
	//握手并建立websocket 连接
 | 
			
		||||
	conn, resp, err := d.Dial(wsURL, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(readResp(resp) + err.Error())
 | 
			
		||||
		utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	} else if resp.StatusCode != 101 {
 | 
			
		||||
		utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	data := buildRequest(key[0], req)
 | 
			
		||||
	fmt.Printf("%+v", data)
 | 
			
		||||
	fmt.Println(apiURL)
 | 
			
		||||
	err = conn.WriteJSON(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
	// 循环读取 Chunk 消息
 | 
			
		||||
	var message = types.Message{}
 | 
			
		||||
	var contents = make([]string, 0)
 | 
			
		||||
	var content string
 | 
			
		||||
	for {
 | 
			
		||||
		_, msg, err := conn.ReadMessage()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with read message:", err)
 | 
			
		||||
			utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 解析数据
 | 
			
		||||
		var result xunFeiResp
 | 
			
		||||
		err = json.Unmarshal(msg, &result)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with parsing JSON:", err)
 | 
			
		||||
			utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if result.Header.Code != 0 {
 | 
			
		||||
			utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		content = result.Payload.Choices.Text[0].Content
 | 
			
		||||
		// 处理代码换行
 | 
			
		||||
		if len(content) == 0 {
 | 
			
		||||
			content = "\n"
 | 
			
		||||
		}
 | 
			
		||||
		contents = append(contents, content)
 | 
			
		||||
		// 第一个结果
 | 
			
		||||
		if result.Payload.Choices.Status == 0 {
 | 
			
		||||
			utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
		}
 | 
			
		||||
		utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
			Type:    types.WsMiddle,
 | 
			
		||||
			Content: utils.InterfaceToString(content),
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if result.Payload.Choices.Status == 2 { // 最终结果
 | 
			
		||||
			_ = conn.Close() // 关闭连接
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ctx.Done():
 | 
			
		||||
			utils.ReplyMessage(ws, "**用户取消了生成指令!**")
 | 
			
		||||
			return nil
 | 
			
		||||
		default:
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	// 消息发送成功
 | 
			
		||||
	if len(contents) > 0 {
 | 
			
		||||
		h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 构建 websocket 请求实体
 | 
			
		||||
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
 | 
			
		||||
	return map[string]interface{}{
 | 
			
		||||
		"header": map[string]interface{}{
 | 
			
		||||
			"app_id": appid,
 | 
			
		||||
		},
 | 
			
		||||
		"parameter": map[string]interface{}{
 | 
			
		||||
			"chat": map[string]interface{}{
 | 
			
		||||
				"domain":      req.Model,
 | 
			
		||||
				"temperature": req.Temperature,
 | 
			
		||||
				"top_k":       int64(6),
 | 
			
		||||
				"max_tokens":  int64(req.MaxTokens),
 | 
			
		||||
				"auditing":    "default",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		"payload": map[string]interface{}{
 | 
			
		||||
			"message": map[string]interface{}{
 | 
			
		||||
				"text": req.Messages,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 创建鉴权 URL
 | 
			
		||||
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
 | 
			
		||||
	ul, err := url.Parse(hostURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	date := time.Now().UTC().Format(time.RFC1123)
 | 
			
		||||
	signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
 | 
			
		||||
	//拼接签名字符串
 | 
			
		||||
	signStr := strings.Join(signString, "\n")
 | 
			
		||||
	sha := hmacWithSha256(signStr, apiSecret)
 | 
			
		||||
 | 
			
		||||
	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
 | 
			
		||||
		"hmac-sha256", "host date request-line", sha)
 | 
			
		||||
	//将请求参数使用base64编码
 | 
			
		||||
	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
 | 
			
		||||
	v := url.Values{}
 | 
			
		||||
	v.Add("host", ul.Host)
 | 
			
		||||
	v.Add("date", date)
 | 
			
		||||
	v.Add("authorization", authorization)
 | 
			
		||||
	//将编码后的字符串url encode后添加到url后面
 | 
			
		||||
	return hostURL + "?" + v.Encode(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 使用 sha256 签名
 | 
			
		||||
func hmacWithSha256(data, key string) string {
 | 
			
		||||
	mac := hmac.New(sha256.New, []byte(key))
 | 
			
		||||
	mac.Write([]byte(data))
 | 
			
		||||
	encodeData := mac.Sum(nil)
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(encodeData)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 读取响应
 | 
			
		||||
func readResp(resp *http.Response) string {
 | 
			
		||||
	if resp == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	b, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
 | 
			
		||||
}
 | 
			
		||||
@@ -8,34 +8,36 @@ package handler
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/dalle"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DallJobHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis    *redis.Client
 | 
			
		||||
	service  *dalle.Service
 | 
			
		||||
	uploader *oss.UploaderManager
 | 
			
		||||
	redis       *redis.Client
 | 
			
		||||
	dallService *dalle.Service
 | 
			
		||||
	uploader    *oss.UploaderManager
 | 
			
		||||
	userService *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
 | 
			
		||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
 | 
			
		||||
	return &DallJobHandler{
 | 
			
		||||
		service:  service,
 | 
			
		||||
		uploader: manager,
 | 
			
		||||
		dallService: service,
 | 
			
		||||
		uploader:    manager,
 | 
			
		||||
		userService: userService,
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
@@ -60,14 +62,14 @@ func (h *DallJobHandler) Client(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.service.Clients.Put(uint(userId), client)
 | 
			
		||||
	h.dallService.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := client.Receive()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				client.Close()
 | 
			
		||||
				h.service.Clients.Delete(uint(userId))
 | 
			
		||||
				h.dallService.Clients.Delete(uint(userId))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@@ -126,7 +128,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.service.PushTask(types.DallTask{
 | 
			
		||||
	h.dallService.PushTask(types.DallTask{
 | 
			
		||||
		JobId:   job.Id,
 | 
			
		||||
		UserId:  uint(userId),
 | 
			
		||||
		Prompt:  data.Prompt,
 | 
			
		||||
@@ -136,7 +138,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
 | 
			
		||||
		Power:   job.Power,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	client := h.service.Clients.Get(job.UserId)
 | 
			
		||||
	client := h.dallService.Clients.Get(job.UserId)
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
@@ -158,13 +160,13 @@ func (h *DallJobHandler) ImgWall(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// JobList 获取 SD 任务列表
 | 
			
		||||
func (h *DallJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetBool(c, "status")
 | 
			
		||||
	finish := h.GetBool(c, "finish")
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
 | 
			
		||||
	err, jobs := h.getData(status, userId, page, pageSize, publish)
 | 
			
		||||
	err, jobs := h.getData(finish, userId, page, pageSize, publish)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
@@ -174,11 +176,11 @@ func (h *DallJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取任务列表
 | 
			
		||||
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
 | 
			
		||||
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if finish {
 | 
			
		||||
		session = session.Where("progress = ?", 100).Order("id DESC")
 | 
			
		||||
		session = session.Where("progress >= ?", 100).Order("id DESC")
 | 
			
		||||
	} else {
 | 
			
		||||
		session = session.Where("progress < ?", 100).Order("id ASC")
 | 
			
		||||
	}
 | 
			
		||||
@@ -192,11 +194,14 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
 | 
			
		||||
		offset := (page - 1) * pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
	}
 | 
			
		||||
	// 统计总数
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.DallJob{}).Count(&total)
 | 
			
		||||
 | 
			
		||||
	var items []model.DallJob
 | 
			
		||||
	res := session.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return res.Error, nil
 | 
			
		||||
		return res.Error, vo.Page{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.DallJob, 0)
 | 
			
		||||
@@ -209,30 +214,44 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
 | 
			
		||||
		jobs = append(jobs, job)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, jobs
 | 
			
		||||
	return nil, vo.NewPage(total, page, pageSize, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove task image
 | 
			
		||||
func (h *DallJobHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint   `json:"id"`
 | 
			
		||||
		UserId uint   `json:"user_id"`
 | 
			
		||||
		ImgURL string `json:"img_url"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var job model.DallJob
 | 
			
		||||
	if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "记录不存在")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove job recode
 | 
			
		||||
	res := h.DB.Delete(&model.DallJob{Id: data.Id})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
	// 删除任务
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	if err := tx.Delete(&job).Error; err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果任务未完成,或者任务失败,则恢复用户算力
 | 
			
		||||
	if job.Progress != 100 {
 | 
			
		||||
		err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
 | 
			
		||||
			Type:   types.PowerRefund,
 | 
			
		||||
			Model:  "dall-e-3",
 | 
			
		||||
			Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
 | 
			
		||||
	// remove image
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("remove image failed: ", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -242,19 +261,13 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// Publish 发布/取消发布图片到画廊显示
 | 
			
		||||
func (h *DallJobHandler) Publish(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint `json:"id"`
 | 
			
		||||
		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
	err := h.DB.Model(&model.DallJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,15 +8,16 @@ package handler
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service/dalle"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@@ -224,3 +225,27 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, content)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 获取所有的工具函数列表
 | 
			
		||||
func (h *FunctionHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.Function
 | 
			
		||||
	err := h.DB.Where("enabled", true).Find(&items).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tools := make([]vo.Function, 0)
 | 
			
		||||
	for _, v := range items {
 | 
			
		||||
		var f vo.Function
 | 
			
		||||
		err = utils.CopyObject(v, &f)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		f.Action = ""
 | 
			
		||||
		f.Token = ""
 | 
			
		||||
		tools = append(tools, f)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, tools)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,6 @@ package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
@@ -59,23 +58,16 @@ func (h *InviteHandler) Code(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// List Log 用户邀请记录
 | 
			
		||||
func (h *InviteHandler) List(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Page     int `json:"page"`
 | 
			
		||||
		PageSize int `json:"page_size"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.InviteLog{}).Count(&total)
 | 
			
		||||
	var items []model.InviteLog
 | 
			
		||||
	var list = make([]vo.InviteLog, 0)
 | 
			
		||||
	offset := (data.Page - 1) * data.PageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
 | 
			
		||||
	offset := (page - 1) * pageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var v vo.InviteLog
 | 
			
		||||
@@ -89,7 +81,7 @@ func (h *InviteHandler) List(c *gin.Context) {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Hits 访问邀请码
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -30,13 +31,15 @@ import (
 | 
			
		||||
// MarkMapHandler 生成思维导图
 | 
			
		||||
type MarkMapHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	clients *types.LMap[int, *types.WsClient]
 | 
			
		||||
	clients     *types.LMap[int, *types.WsClient]
 | 
			
		||||
	userService *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
 | 
			
		||||
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler {
 | 
			
		||||
	return &MarkMapHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{App: app, DB: db},
 | 
			
		||||
		clients:     types.NewLMap[int, *types.WsClient](),
 | 
			
		||||
		userService: userService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -101,17 +104,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 | 
			
		||||
		return fmt.Errorf("error with query chat model: %v", res.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Status == false {
 | 
			
		||||
		return errors.New("当前用户被禁用")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Power < chatModel.Power {
 | 
			
		||||
		return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	messages := make([]interface{}, 0)
 | 
			
		||||
	messages = append(messages, types.Message{Role: "system", Content: `
 | 
			
		||||
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
 | 
			
		||||
你是一位非常优秀的思维导图助手, 你能帮助用户整理思路,根据用户提供的主题或内容,快速生成结构清晰,有条理的思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
 | 
			
		||||
# Geek-AI 助手
 | 
			
		||||
 | 
			
		||||
## 完整的开源系统
 | 
			
		||||
@@ -130,7 +129,7 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 | 
			
		||||
 | 
			
		||||
另外,除此之外不要任何解释性语句。
 | 
			
		||||
`})
 | 
			
		||||
	messages = append(messages, types.Message{Role: "user", Content: prompt})
 | 
			
		||||
	messages = append(messages, types.Message{Role: "user", Content: fmt.Sprintf("请生成一份有关【%s】一份思维导图,要求结构清晰,有条理", prompt)})
 | 
			
		||||
	var req = types.ApiRequest{
 | 
			
		||||
		Model:    chatModel.Value,
 | 
			
		||||
		Stream:   true,
 | 
			
		||||
@@ -183,66 +182,41 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 | 
			
		||||
		utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("读取响应失败: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		var res types.ApiError
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("解析响应失败: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// OpenAI API 调用异常处理
 | 
			
		||||
		if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
 | 
			
		||||
			// remove key
 | 
			
		||||
			h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
 | 
			
		||||
			return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
 | 
			
		||||
		} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
 | 
			
		||||
			return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
 | 
			
		||||
		} else {
 | 
			
		||||
			return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message)
 | 
			
		||||
		}
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求 OpenAI API 失败:%s", string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 扣减算力
 | 
			
		||||
	res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		// 记录算力消费日志
 | 
			
		||||
		var u model.User
 | 
			
		||||
		h.DB.Where("id", userId).First(&u)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    u.Id,
 | 
			
		||||
			Username:  u.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    chatModel.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Balance:   u.Power,
 | 
			
		||||
			Model:     chatModel.Value,
 | 
			
		||||
			Remark:    fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
	if chatModel.Power > 0 {
 | 
			
		||||
		err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
 | 
			
		||||
			Type:   types.PowerConsume,
 | 
			
		||||
			Model:  chatModel.Value,
 | 
			
		||||
			Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	// if the chat model bind a KEY, use it directly
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	if chatModel.KeyId > 0 {
 | 
			
		||||
		res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		res = h.DB.Where("platform", types.OpenAI).
 | 
			
		||||
			Where("type", "chat").
 | 
			
		||||
			Where("enabled", true).Order("last_used_at ASC").First(apiKey)
 | 
			
		||||
		session = session.Where("id", chatModel.KeyId)
 | 
			
		||||
	} else { // use the last unused key
 | 
			
		||||
		session = session.Where("type", "chat").
 | 
			
		||||
			Where("enabled", true).Order("last_used_at ASC")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := session.First(apiKey)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return nil, errors.New("no available key, please import key")
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := apiKey.ApiURL
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
 | 
			
		||||
@@ -269,5 +243,6 @@ func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatMod
 | 
			
		||||
		client = http.DefaultClient
 | 
			
		||||
	}
 | 
			
		||||
	request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
	logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
 | 
			
		||||
	return client.Do(request)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -27,9 +27,15 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
 | 
			
		||||
 | 
			
		||||
// List 数据列表
 | 
			
		||||
func (h *MenuHandler) List(c *gin.Context) {
 | 
			
		||||
	index := h.GetBool(c, "index")
 | 
			
		||||
	var items []model.Menu
 | 
			
		||||
	var list = make([]vo.Menu, 0)
 | 
			
		||||
	res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	session = session.Where("enabled", true)
 | 
			
		||||
	if index {
 | 
			
		||||
		session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
 | 
			
		||||
	}
 | 
			
		||||
	res := session.Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var product vo.Menu
 | 
			
		||||
 
 | 
			
		||||
@@ -30,16 +30,18 @@ import (
 | 
			
		||||
 | 
			
		||||
type MidJourneyHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	pool      *mj.ServicePool
 | 
			
		||||
	snowflake *service.Snowflake
 | 
			
		||||
	uploader  *oss.UploaderManager
 | 
			
		||||
	mjService   *mj.Service
 | 
			
		||||
	snowflake   *service.Snowflake
 | 
			
		||||
	uploader    *oss.UploaderManager
 | 
			
		||||
	userService *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
 | 
			
		||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
 | 
			
		||||
	return &MidJourneyHandler{
 | 
			
		||||
		snowflake: snowflake,
 | 
			
		||||
		pool:      pool,
 | 
			
		||||
		uploader:  manager,
 | 
			
		||||
		snowflake:   snowflake,
 | 
			
		||||
		mjService:   service,
 | 
			
		||||
		uploader:    manager,
 | 
			
		||||
		userService: userService,
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
@@ -59,11 +61,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.pool.HasAvailableService() {
 | 
			
		||||
		resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return true
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -85,26 +82,25 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.pool.Clients.Put(uint(userId), client)
 | 
			
		||||
	h.mjService.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Image 创建一个绘画任务
 | 
			
		||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		SessionId string   `json:"session_id"`
 | 
			
		||||
		TaskType  string   `json:"task_type"`
 | 
			
		||||
		Prompt    string   `json:"prompt"`
 | 
			
		||||
		NegPrompt string   `json:"neg_prompt"`
 | 
			
		||||
		Rate      string   `json:"rate"`
 | 
			
		||||
		Model     string   `json:"model"`
 | 
			
		||||
		Chaos     int      `json:"chaos"`
 | 
			
		||||
		Raw       bool     `json:"raw"`
 | 
			
		||||
		Seed      int64    `json:"seed"`
 | 
			
		||||
		Stylize   int      `json:"stylize"`
 | 
			
		||||
		Model     string   `json:"model"`   // 模型
 | 
			
		||||
		Chaos     int      `json:"chaos"`   // 创意度取值范围: 0-100
 | 
			
		||||
		Raw       bool     `json:"raw"`     // 是否开启原始模型
 | 
			
		||||
		Seed      int64    `json:"seed"`    // 随机数
 | 
			
		||||
		Stylize   int      `json:"stylize"` // 风格化
 | 
			
		||||
		ImgArr    []string `json:"img_arr"`
 | 
			
		||||
		Tile      bool     `json:"tile"`
 | 
			
		||||
		Quality   float32  `json:"quality"`
 | 
			
		||||
		Tile      bool     `json:"tile"`    // 重复平铺
 | 
			
		||||
		Quality   float32  `json:"quality"` // 画质
 | 
			
		||||
		Iw        float32  `json:"iw"`
 | 
			
		||||
		CRef      string   `json:"cref"` //生成角色一致的图像
 | 
			
		||||
		SRef      string   `json:"sref"` //生成风格一致的图像
 | 
			
		||||
@@ -202,40 +198,34 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
	h.mjService.PushTask(types.MjTask{
 | 
			
		||||
		Id:        job.Id,
 | 
			
		||||
		TaskId:    taskId,
 | 
			
		||||
		SessionId: data.SessionId,
 | 
			
		||||
		Type:      types.TaskType(data.TaskType),
 | 
			
		||||
		Prompt:    data.Prompt,
 | 
			
		||||
		NegPrompt: data.NegPrompt,
 | 
			
		||||
		Params:    params,
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		ImgArr:    data.ImgArr,
 | 
			
		||||
		Mode:      h.App.SysConfig.MjMode,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	client := h.mjService.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "mid-journey",
 | 
			
		||||
			Remark:    fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerConsume,
 | 
			
		||||
		Model:  "mid-journey",
 | 
			
		||||
		Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -244,17 +234,12 @@ type reqVo struct {
 | 
			
		||||
	ChannelId   string `json:"channel_id"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
	SessionId   string `json:"session_id"`
 | 
			
		||||
	Prompt      string `json:"prompt"`
 | 
			
		||||
	ChatId      string `json:"chat_id"`
 | 
			
		||||
	RoleId      int    `json:"role_id"`
 | 
			
		||||
	Icon        string `json:"icon"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale send upscale command to MidJourney Bot
 | 
			
		||||
func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
	var data reqVo
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" {
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -272,7 +257,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		TaskId:      taskId,
 | 
			
		||||
		Progress:    0,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		Power:       h.App.SysConfig.MjActionPower,
 | 
			
		||||
		CreatedAt:   time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
@@ -281,46 +265,40 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
	h.mjService.PushTask(types.MjTask{
 | 
			
		||||
		Id:          job.Id,
 | 
			
		||||
		SessionId:   data.SessionId,
 | 
			
		||||
		Type:        types.TaskUpscale,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		ChannelId:   data.ChannelId,
 | 
			
		||||
		Index:       data.Index,
 | 
			
		||||
		MessageId:   data.MessageId,
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
		Mode:        h.App.SysConfig.MjMode,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	client := h.mjService.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "mid-journey",
 | 
			
		||||
			Remark:    fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerConsume,
 | 
			
		||||
		Model:  "mid-journey",
 | 
			
		||||
		Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation send variation command to MidJourney Bot
 | 
			
		||||
func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
	var data reqVo
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" {
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -339,7 +317,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		TaskId:      taskId,
 | 
			
		||||
		Progress:    0,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		Power:       h.App.SysConfig.MjActionPower,
 | 
			
		||||
		CreatedAt:   time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
@@ -348,40 +325,32 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
	h.mjService.PushTask(types.MjTask{
 | 
			
		||||
		Id:          job.Id,
 | 
			
		||||
		SessionId:   data.SessionId,
 | 
			
		||||
		Type:        types.TaskVariation,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		Index:       data.Index,
 | 
			
		||||
		ChannelId:   data.ChannelId,
 | 
			
		||||
		MessageId:   data.MessageId,
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
		Mode:        h.App.SysConfig.MjMode,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	client := h.mjService.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "mid-journey",
 | 
			
		||||
			Remark:    fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerConsume,
 | 
			
		||||
		Model:  "mid-journey",
 | 
			
		||||
		Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -400,13 +369,13 @@ func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetBool(c, "status")
 | 
			
		||||
	finish := h.GetBool(c, "finish")
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
 | 
			
		||||
	err, jobs := h.getData(status, userId, page, pageSize, publish)
 | 
			
		||||
	err, jobs := h.getData(finish, userId, page, pageSize, publish)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
@@ -416,10 +385,10 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
 | 
			
		||||
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if finish {
 | 
			
		||||
		session = session.Where("progress = ?", 100).Order("id DESC")
 | 
			
		||||
		session = session.Where("progress >= ?", 100).Order("id DESC")
 | 
			
		||||
	} else {
 | 
			
		||||
		session = session.Where("progress < ?", 100).Order("id ASC")
 | 
			
		||||
	}
 | 
			
		||||
@@ -434,10 +403,14 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计总数
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.MidJourneyJob{}).Count(&total)
 | 
			
		||||
 | 
			
		||||
	var items []model.MidJourneyJob
 | 
			
		||||
	res := session.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return res.Error, nil
 | 
			
		||||
		return res.Error, vo.Page{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.MidJourneyJob, 0)
 | 
			
		||||
@@ -449,48 +422,57 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
 | 
			
		||||
			// discord 服务器图片需要使用代理转发图片数据流
 | 
			
		||||
			if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
 | 
			
		||||
				image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				job.ImgURL = job.OrgURL
 | 
			
		||||
			image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		jobs = append(jobs, job)
 | 
			
		||||
	}
 | 
			
		||||
	return nil, jobs
 | 
			
		||||
	return nil, vo.NewPage(total, page, pageSize, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove task image
 | 
			
		||||
func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint   `json:"id"`
 | 
			
		||||
		UserId uint   `json:"user_id"`
 | 
			
		||||
		ImgURL string `json:"img_url"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	var job model.MidJourneyJob
 | 
			
		||||
	if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "记录不存在")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove job recode
 | 
			
		||||
	res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	if err := tx.Delete(&job).Error; err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果任务未完成,或者任务失败,则恢复用户算力
 | 
			
		||||
	if job.Progress != 100 {
 | 
			
		||||
		err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
			Type:   types.PowerRefund,
 | 
			
		||||
			Model:  "mid-journey",
 | 
			
		||||
			Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
 | 
			
		||||
	// remove image
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("remove image failed: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(data.UserId)
 | 
			
		||||
	client := h.mjService.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
@@ -500,19 +482,12 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// Publish 发布图片到画廊显示
 | 
			
		||||
func (h *MidJourneyHandler) Publish(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint `json:"id"`
 | 
			
		||||
		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
	err := h.DB.Model(&model.MidJourneyJob{Id: uint(id), UserId: userId}).UpdateColumn("publish", action).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -14,6 +14,7 @@ import (
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -27,23 +28,18 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
 | 
			
		||||
	return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 订单列表
 | 
			
		||||
func (h *OrderHandler) List(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Page     int `json:"page"`
 | 
			
		||||
		PageSize int `json:"page_size"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.Order{}).Count(&total)
 | 
			
		||||
	var items []model.Order
 | 
			
		||||
	var list = make([]vo.Order, 0)
 | 
			
		||||
	offset := (data.Page - 1) * data.PageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
 | 
			
		||||
	offset := (page - 1) * pageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var order vo.Order
 | 
			
		||||
@@ -58,5 +54,35 @@ func (h *OrderHandler) List(c *gin.Context) {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Query 查询订单状态
 | 
			
		||||
func (h *OrderHandler) Query(c *gin.Context) {
 | 
			
		||||
	orderNo := h.GetTrim(c, "order_no")
 | 
			
		||||
	var order model.Order
 | 
			
		||||
	res := h.DB.Where("order_no = ?", orderNo).First(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Order not found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if order.Status == types.OrderPaidSuccess {
 | 
			
		||||
		resp.SUCCESS(c, gin.H{"status": order.Status})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	counter := 0
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Second)
 | 
			
		||||
		var item model.Order
 | 
			
		||||
		h.DB.Where("order_no = ?", orderNo).First(&item)
 | 
			
		||||
		if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
 | 
			
		||||
			order.Status = item.Status
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		counter++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"status": order.Status})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -29,39 +29,48 @@ import (
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	PayWayAlipay = "支付宝"
 | 
			
		||||
	PayWayXunHu  = "虎皮椒"
 | 
			
		||||
	PayWayJs     = "PayJS"
 | 
			
		||||
type PayWay struct {
 | 
			
		||||
	Name  string `json:"name"`
 | 
			
		||||
	Value string `json:"value"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	PayWayAlipay = PayWay{Name: "支付宝", Value: "alipay"}
 | 
			
		||||
	PayWayXunHu  = PayWay{Name: "虎皮椒", Value: "hupi"}
 | 
			
		||||
	PayWayJs     = PayWay{Name: "PayJS", Value: "payjs"}
 | 
			
		||||
	PayWayWechat = PayWay{Name: "微信支付", Value: "wechat"}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PaymentHandler 支付服务回调 handler
 | 
			
		||||
type PaymentHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	alipayService  *payment.AlipayService
 | 
			
		||||
	huPiPayService *payment.HuPiPayService
 | 
			
		||||
	js             *payment.PayJS
 | 
			
		||||
	snowflake      *service.Snowflake
 | 
			
		||||
	fs             embed.FS
 | 
			
		||||
	lock           sync.Mutex
 | 
			
		||||
	signKey        string // 用来签名的随机秘钥
 | 
			
		||||
	alipayService    *payment.AlipayService
 | 
			
		||||
	huPiPayService   *payment.HuPiPayService
 | 
			
		||||
	jsPayService     *payment.JPayService
 | 
			
		||||
	wechatPayService *payment.WechatPayService
 | 
			
		||||
	snowflake        *service.Snowflake
 | 
			
		||||
	fs               embed.FS
 | 
			
		||||
	lock             sync.Mutex
 | 
			
		||||
	signKey          string // 用来签名的随机秘钥
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPaymentHandler(
 | 
			
		||||
	server *core.AppServer,
 | 
			
		||||
	alipayService *payment.AlipayService,
 | 
			
		||||
	huPiPayService *payment.HuPiPayService,
 | 
			
		||||
	js *payment.PayJS,
 | 
			
		||||
	jsPayService *payment.JPayService,
 | 
			
		||||
	wechatPayService *payment.WechatPayService,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	snowflake *service.Snowflake,
 | 
			
		||||
	fs embed.FS) *PaymentHandler {
 | 
			
		||||
	return &PaymentHandler{
 | 
			
		||||
		alipayService:  alipayService,
 | 
			
		||||
		huPiPayService: huPiPayService,
 | 
			
		||||
		js:             js,
 | 
			
		||||
		snowflake:      snowflake,
 | 
			
		||||
		fs:             fs,
 | 
			
		||||
		lock:           sync.Mutex{},
 | 
			
		||||
		alipayService:    alipayService,
 | 
			
		||||
		huPiPayService:   huPiPayService,
 | 
			
		||||
		jsPayService:     jsPayService,
 | 
			
		||||
		wechatPayService: wechatPayService,
 | 
			
		||||
		snowflake:        snowflake,
 | 
			
		||||
		fs:               fs,
 | 
			
		||||
		lock:             sync.Mutex{},
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: server,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
@@ -102,19 +111,16 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
 | 
			
		||||
	if order.Status == types.OrderPaidSuccess {
 | 
			
		||||
		resp.ERROR(c, "This order had been paid, please do not pay twice")
 | 
			
		||||
		resp.ERROR(c, "订单已支付成功,无需重复支付!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新扫码状态
 | 
			
		||||
	h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
 | 
			
		||||
	if payWay == "alipay" { // 支付宝
 | 
			
		||||
		// 生成支付链接
 | 
			
		||||
		notifyURL := h.App.Config.AlipayConfig.NotifyURL
 | 
			
		||||
		returnURL := "" // 关闭同步回跳
 | 
			
		||||
		amount := fmt.Sprintf("%.2f", order.Amount)
 | 
			
		||||
 | 
			
		||||
		uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
 | 
			
		||||
	if payWay == "alipay" { // 支付宝
 | 
			
		||||
		amount := fmt.Sprintf("%.2f", order.Amount)
 | 
			
		||||
		uri, err := h.alipayService.PayUrlMobile(order.OrderNo, amount, order.Subject)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "error with generate pay url: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
@@ -142,49 +148,11 @@ func (h *PaymentHandler) DoPay(c *gin.Context) {
 | 
			
		||||
	resp.ERROR(c, "Invalid operations")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OrderQuery 查询订单状态
 | 
			
		||||
func (h *PaymentHandler) OrderQuery(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		OrderNo string `json:"order_no"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var order model.Order
 | 
			
		||||
	res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Order not found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if order.Status == types.OrderPaidSuccess {
 | 
			
		||||
		resp.SUCCESS(c, gin.H{"status": order.Status})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	counter := 0
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Second)
 | 
			
		||||
		var item model.Order
 | 
			
		||||
		h.DB.Where("order_no = ?", data.OrderNo).First(&item)
 | 
			
		||||
		if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
 | 
			
		||||
			order.Status = item.Status
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		counter++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"status": order.Status})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PayQrcode 生成支付 URL 二维码
 | 
			
		||||
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		PayWay    string `json:"pay_way"` // 支付方式
 | 
			
		||||
		ProductId uint   `json:"product_id"`
 | 
			
		||||
		UserId    int    `json:"user_id"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -203,10 +171,9 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
 | 
			
		||||
		resp.ERROR(c, "error with generate trade no: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res = h.DB.First(&user, data.UserId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Invalid user ID")
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -214,14 +181,21 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
 | 
			
		||||
	var notifyURL string
 | 
			
		||||
	switch data.PayWay {
 | 
			
		||||
	case "hupi":
 | 
			
		||||
		payWay = PayWayXunHu
 | 
			
		||||
		payWay = PayWayXunHu.Value
 | 
			
		||||
		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
 | 
			
		||||
		break
 | 
			
		||||
	case "payjs":
 | 
			
		||||
		payWay = PayWayJs
 | 
			
		||||
		payWay = PayWayJs.Value
 | 
			
		||||
		notifyURL = h.App.Config.JPayConfig.NotifyURL
 | 
			
		||||
	default:
 | 
			
		||||
		payWay = PayWayAlipay
 | 
			
		||||
		break
 | 
			
		||||
	case "alipay":
 | 
			
		||||
		payWay = PayWayAlipay.Value
 | 
			
		||||
		notifyURL = h.App.Config.AlipayConfig.NotifyURL
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
		payWay = PayWayWechat.Value
 | 
			
		||||
		notifyURL = h.App.Config.WechatPayConfig.NotifyURL
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	// 创建订单
 | 
			
		||||
	remark := types.OrderRemark{
 | 
			
		||||
@@ -257,7 +231,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
 | 
			
		||||
			OutTradeNo: order.OrderNo,
 | 
			
		||||
			Subject:    product.Name,
 | 
			
		||||
		}
 | 
			
		||||
		r := h.js.Pay(params)
 | 
			
		||||
		r := h.jsPayService.Pay(params)
 | 
			
		||||
		if r.IsOK() {
 | 
			
		||||
			resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
 | 
			
		||||
			return
 | 
			
		||||
@@ -276,6 +250,8 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
 | 
			
		||||
		} else {
 | 
			
		||||
			logo = "res/img/alipay.jpg"
 | 
			
		||||
		}
 | 
			
		||||
	} else if data.PayWay == "wechat" {
 | 
			
		||||
		logo = "res/img/wechat-pay.jpg"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	file, err := h.fs.Open(logo)
 | 
			
		||||
@@ -292,7 +268,18 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
 | 
			
		||||
	timestamp := time.Now().Unix()
 | 
			
		||||
	signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey)
 | 
			
		||||
	sign := utils.Sha256(signStr)
 | 
			
		||||
	imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
 | 
			
		||||
	var imageURL string
 | 
			
		||||
	if data.PayWay == "wechat" {
 | 
			
		||||
		payUrl, err := h.wechatPayService.PayUrlNative(order.OrderNo, int(math.Floor(order.Amount*100)), product.Name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "error with generating wechat payment qrcode: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		} else {
 | 
			
		||||
			imageURL = payUrl
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		imageURL = fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign)
 | 
			
		||||
	}
 | 
			
		||||
	imgData, err := utils.GenQrcode(imageURL, 400, file)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
@@ -307,7 +294,6 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		PayWay    string `json:"pay_way"` // 支付方式
 | 
			
		||||
		ProductId uint   `json:"product_id"`
 | 
			
		||||
		UserId    int    `json:"user_id"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -326,10 +312,9 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
 | 
			
		||||
		resp.ERROR(c, "error with generate trade no: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res = h.DB.First(&user, data.UserId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Invalid user ID")
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -339,7 +324,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
 | 
			
		||||
	var payURL string
 | 
			
		||||
	switch data.PayWay {
 | 
			
		||||
	case "hupi":
 | 
			
		||||
		payWay = PayWayXunHu
 | 
			
		||||
		payWay = PayWayXunHu.Name
 | 
			
		||||
		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
 | 
			
		||||
		returnURL = h.App.Config.HuPiPayConfig.ReturnURL
 | 
			
		||||
		parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL)
 | 
			
		||||
@@ -358,13 +343,14 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
 | 
			
		||||
		}
 | 
			
		||||
		r, err := h.huPiPayService.Pay(params)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with generating Pay URL: ", err.Error())
 | 
			
		||||
			resp.ERROR(c, "error with generating Pay URL: "+err.Error())
 | 
			
		||||
			errMsg := "error with generating Pay Hupi URL: " + err.Error()
 | 
			
		||||
			logger.Error(errMsg)
 | 
			
		||||
			resp.ERROR(c, errMsg)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		payURL = r.URL
 | 
			
		||||
	case "payjs":
 | 
			
		||||
		payWay = PayWayJs
 | 
			
		||||
		payWay = PayWayJs.Name
 | 
			
		||||
		notifyURL = h.App.Config.JPayConfig.NotifyURL
 | 
			
		||||
		returnURL = h.App.Config.JPayConfig.ReturnURL
 | 
			
		||||
		totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
 | 
			
		||||
@@ -374,14 +360,22 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
 | 
			
		||||
		params.Add("body", product.Name)
 | 
			
		||||
		params.Add("notify_url", notifyURL)
 | 
			
		||||
		params.Add("auto", "0")
 | 
			
		||||
		payURL = h.js.PayH5(params)
 | 
			
		||||
		payURL = h.jsPayService.PayH5(params)
 | 
			
		||||
	case "alipay":
 | 
			
		||||
		payWay = PayWayAlipay
 | 
			
		||||
		notifyURL = h.App.Config.AlipayConfig.NotifyURL
 | 
			
		||||
		returnURL = h.App.Config.AlipayConfig.ReturnURL
 | 
			
		||||
		payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
 | 
			
		||||
		payWay = PayWayAlipay.Name
 | 
			
		||||
		payURL, err = h.alipayService.PayUrlMobile(orderNo, fmt.Sprintf("%.2f", amount), product.Name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "error with generating Pay URL: "+err.Error())
 | 
			
		||||
			errMsg := "error with generating Alipay URL: " + err.Error()
 | 
			
		||||
			resp.ERROR(c, errMsg)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case "wechat":
 | 
			
		||||
		payWay = PayWayWechat.Name
 | 
			
		||||
		payURL, err = h.wechatPayService.PayUrlH5(orderNo, int(amount*100), product.Name, c.ClientIP())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			errMsg := "error with generating Wechat URL: " + err.Error()
 | 
			
		||||
			logger.Error(errMsg)
 | 
			
		||||
			resp.ERROR(c, errMsg)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
@@ -414,7 +408,7 @@ func (h *PaymentHandler) Mobile(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, payURL)
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"url": payURL, "order_no": orderNo})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 异步通知回调公共逻辑
 | 
			
		||||
@@ -493,7 +487,7 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
 | 
			
		||||
	h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
 | 
			
		||||
 | 
			
		||||
	// 记录算力充值日志
 | 
			
		||||
	if opt != "" {
 | 
			
		||||
	if power > 0 {
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
@@ -522,6 +516,9 @@ func (h *PaymentHandler) GetPayWays(c *gin.Context) {
 | 
			
		||||
	if h.App.Config.JPayConfig.Enabled {
 | 
			
		||||
		data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
 | 
			
		||||
	}
 | 
			
		||||
	if h.App.Config.WechatPayConfig.Enabled {
 | 
			
		||||
		data["wechat"] = gin.H{"name": "wechat"}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -560,7 +557,7 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO:验证交易签名
 | 
			
		||||
	res := h.alipayService.TradeVerify(c.Request.Form)
 | 
			
		||||
	res := h.alipayService.TradeVerify(c.Request)
 | 
			
		||||
	logger.Infof("验证支付结果:%+v", res)
 | 
			
		||||
	if !res.Success() {
 | 
			
		||||
		logger.Error("订单校验失败:", res.Message)
 | 
			
		||||
@@ -588,7 +585,7 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	orderNo := c.Request.Form.Get("out_trade_no")
 | 
			
		||||
	returnCode := c.Request.Form.Get("return_code")
 | 
			
		||||
	logger.Infof("收到订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
 | 
			
		||||
	logger.Infof("收到PayJs订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
 | 
			
		||||
	// 支付失败
 | 
			
		||||
	if returnCode != "1" {
 | 
			
		||||
		return
 | 
			
		||||
@@ -596,7 +593,7 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	// 校验订单支付状态
 | 
			
		||||
	tradeNo := c.Request.Form.Get("payjs_order_id")
 | 
			
		||||
	err = h.js.Check(tradeNo)
 | 
			
		||||
	err = h.jsPayService.TradeVerify(tradeNo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("订单校验失败:", err)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
@@ -611,3 +608,30 @@ func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	c.String(http.StatusOK, "success")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WechatPayNotify 微信商户支付异步回调
 | 
			
		||||
func (h *PaymentHandler) WechatPayNotify(c *gin.Context) {
 | 
			
		||||
	err := c.Request.ParseForm()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := h.wechatPayService.TradeVerify(c.Request)
 | 
			
		||||
	if !result.Success() {
 | 
			
		||||
		logger.Error("订单校验失败:", err)
 | 
			
		||||
		c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
			"code":    "FAIL",
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = h.notify(result.OutTradeNo, result.TradeId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.String(http.StatusOK, "success")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										88
									
								
								api/handler/redeem_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								api/handler/redeem_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RedeemHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	lock        sync.Mutex
 | 
			
		||||
	userService *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler {
 | 
			
		||||
	return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RedeemHandler) Verify(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Code string `json:"code"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	var item model.Redeem
 | 
			
		||||
	res := h.DB.Where("code", data.Code).First(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "无效的兑换码!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !item.Enabled {
 | 
			
		||||
		resp.ERROR(c, "当前兑换码已被禁用!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if item.RedeemedAt > 0 {
 | 
			
		||||
		resp.ERROR(c, "当前兑换码已使用,请勿重复使用!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerRedeem,
 | 
			
		||||
		Model:  "兑换码",
 | 
			
		||||
		Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新核销状态
 | 
			
		||||
	item.RedeemedAt = time.Now().Unix()
 | 
			
		||||
	item.UserId = userId
 | 
			
		||||
	err = tx.Updates(&item).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -1,108 +0,0 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"math"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RewardHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	lock sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
 | 
			
		||||
	return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Verify 打赏码核销
 | 
			
		||||
func (h *RewardHandler) Verify(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		TxId string `json:"tx_id"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.HACKER(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
 | 
			
		||||
	data.TxId = strings.ReplaceAll(data.TxId, " ", "")
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	var item model.Reward
 | 
			
		||||
	res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "无效的众筹交易流水号!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if item.Status {
 | 
			
		||||
		resp.ERROR(c, "当前众筹交易流水号已经被核销,请不要重复核销!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	exchange := vo.RewardExchange{}
 | 
			
		||||
	power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
 | 
			
		||||
	exchange.Power = int(power)
 | 
			
		||||
	res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		logger.Error("添加应用失败:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新核销状态
 | 
			
		||||
	item.Status = true
 | 
			
		||||
	item.UserId = user.Id
 | 
			
		||||
	item.Exchange = utils.JsonEncode(exchange)
 | 
			
		||||
	res = tx.Updates(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		logger.Error("添加应用失败:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 记录算力充值日志
 | 
			
		||||
	h.DB.Create(&model.PowerLog{
 | 
			
		||||
		UserId:    user.Id,
 | 
			
		||||
		Username:  user.Username,
 | 
			
		||||
		Type:      types.PowerReward,
 | 
			
		||||
		Amount:    exchange.Power,
 | 
			
		||||
		Balance:   user.Power + exchange.Power,
 | 
			
		||||
		Mark:      types.PowerAdd,
 | 
			
		||||
		Model:     "众筹支付",
 | 
			
		||||
		Remark:    fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice),
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	})
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -31,19 +31,27 @@ import (
 | 
			
		||||
 | 
			
		||||
type SdJobHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	pool      *sd.ServicePool
 | 
			
		||||
	uploader  *oss.UploaderManager
 | 
			
		||||
	snowflake *service.Snowflake
 | 
			
		||||
	leveldb   *store.LevelDB
 | 
			
		||||
	redis       *redis.Client
 | 
			
		||||
	sdService   *sd.Service
 | 
			
		||||
	uploader    *oss.UploaderManager
 | 
			
		||||
	snowflake   *service.Snowflake
 | 
			
		||||
	leveldb     *store.LevelDB
 | 
			
		||||
	userService *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
 | 
			
		||||
func NewSdJobHandler(app *core.AppServer,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	service *sd.Service,
 | 
			
		||||
	manager *oss.UploaderManager,
 | 
			
		||||
	snowflake *service.Snowflake,
 | 
			
		||||
	userService *service.UserService,
 | 
			
		||||
	levelDB *store.LevelDB) *SdJobHandler {
 | 
			
		||||
	return &SdJobHandler{
 | 
			
		||||
		pool:      pool,
 | 
			
		||||
		uploader:  manager,
 | 
			
		||||
		snowflake: snowflake,
 | 
			
		||||
		leveldb:   levelDB,
 | 
			
		||||
		sdService:   service,
 | 
			
		||||
		uploader:    manager,
 | 
			
		||||
		snowflake:   snowflake,
 | 
			
		||||
		leveldb:     levelDB,
 | 
			
		||||
		userService: userService,
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
@@ -68,7 +76,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.pool.Clients.Put(uint(userId), client)
 | 
			
		||||
	h.sdService.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -79,11 +87,6 @@ func (h *SdJobHandler) preCheck(c *gin.Context) bool {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.pool.HasAvailableService() {
 | 
			
		||||
		resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Power < h.App.SysConfig.SdPower {
 | 
			
		||||
		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
 | 
			
		||||
		return false
 | 
			
		||||
@@ -99,10 +102,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		SessionId string `json:"session_id"`
 | 
			
		||||
		types.SdTaskParams
 | 
			
		||||
	}
 | 
			
		||||
	var data types.SdTaskParams
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
@@ -167,35 +167,27 @@ func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.SdTask{
 | 
			
		||||
		Id:        int(job.Id),
 | 
			
		||||
		SessionId: data.SessionId,
 | 
			
		||||
		Type:      types.TaskImage,
 | 
			
		||||
		Params:    params,
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
	h.sdService.PushTask(types.SdTask{
 | 
			
		||||
		Id:     int(job.Id),
 | 
			
		||||
		Type:   types.TaskImage,
 | 
			
		||||
		Params: params,
 | 
			
		||||
		UserId: userId,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	client := h.sdService.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "stable-diffusion",
 | 
			
		||||
			Remark:    fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerConsume,
 | 
			
		||||
		Model:  "stable-diffusion",
 | 
			
		||||
		Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
@@ -216,13 +208,13 @@ func (h *SdJobHandler) ImgWall(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// JobList 获取 SD 任务列表
 | 
			
		||||
func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetBool(c, "status")
 | 
			
		||||
	finish := h.GetBool(c, "finish")
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
 | 
			
		||||
	err, jobs := h.getData(status, userId, page, pageSize, publish)
 | 
			
		||||
	err, jobs := h.getData(finish, userId, page, pageSize, publish)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
@@ -232,11 +224,11 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
 | 
			
		||||
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if finish {
 | 
			
		||||
		session = session.Where("progress = ?", 100).Order("id DESC")
 | 
			
		||||
		session = session.Where("progress >= ?", 100).Order("id DESC")
 | 
			
		||||
	} else {
 | 
			
		||||
		session = session.Where("progress < ?", 100).Order("id ASC")
 | 
			
		||||
	}
 | 
			
		||||
@@ -251,10 +243,14 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计总数
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.SdJob{}).Count(&total)
 | 
			
		||||
 | 
			
		||||
	var items []model.SdJob
 | 
			
		||||
	res := session.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return res.Error, nil
 | 
			
		||||
		return res.Error, vo.Page{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.SdJob, 0)
 | 
			
		||||
@@ -276,57 +272,60 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
 | 
			
		||||
		jobs = append(jobs, job)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, jobs
 | 
			
		||||
	return nil, vo.NewPage(total, page, pageSize, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove task image
 | 
			
		||||
func (h *SdJobHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint   `json:"id"`
 | 
			
		||||
		UserId uint   `json:"user_id"`
 | 
			
		||||
		ImgURL string `json:"img_url"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var job model.SdJob
 | 
			
		||||
	if res := h.DB.Where("id = ? AND user_id = ?", id, userId).First(&job); res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "记录不存在")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove job recode
 | 
			
		||||
	res := h.DB.Delete(&model.SdJob{Id: data.Id})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
	// 删除任务
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	if err := tx.Delete(&job).Error; err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果任务未完成,或者任务失败,则恢复用户算力
 | 
			
		||||
	if job.Progress != 100 {
 | 
			
		||||
		err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
			Type:   types.PowerRefund,
 | 
			
		||||
			Model:  "stable-diffusion",
 | 
			
		||||
			Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s, Err: %s", job.TaskId, job.ErrMsg),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
 | 
			
		||||
	// remove image
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("remove image failed: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(data.UserId)
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte(sd.Finished))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Publish 发布/取消发布图片到画廊显示
 | 
			
		||||
func (h *SdJobHandler) Publish(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint `json:"id"`
 | 
			
		||||
		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	action := h.GetBool(c, "action") // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
	err := h.DB.Model(&model.SdJob{Id: uint(id), UserId: int(userId)}).UpdateColumn("publish", action).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -49,28 +49,36 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Receiver string `json:"receiver"` // 接收者
 | 
			
		||||
		Key      string `json:"key"`
 | 
			
		||||
		Dots     string `json:"dots"`
 | 
			
		||||
		Dots     string `json:"dots,omitempty"`
 | 
			
		||||
		X        int    `json:"x,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.captcha.Check(data) {
 | 
			
		||||
		resp.ERROR(c, "验证码错误,请先完人机验证")
 | 
			
		||||
		return
 | 
			
		||||
	if h.App.SysConfig.EnabledVerify {
 | 
			
		||||
		var check bool
 | 
			
		||||
		if data.X != 0 {
 | 
			
		||||
			check = h.captcha.SlideCheck(data)
 | 
			
		||||
		} else {
 | 
			
		||||
			check = h.captcha.Check(data)
 | 
			
		||||
		}
 | 
			
		||||
		if !check {
 | 
			
		||||
			resp.ERROR(c, "请先完人机验证")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	code := utils.RandomNumber(6)
 | 
			
		||||
	var err error
 | 
			
		||||
	if strings.Contains(data.Receiver, "@") { // email
 | 
			
		||||
		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") {
 | 
			
		||||
		if !utils.Contains(h.App.SysConfig.RegisterWays, "email") {
 | 
			
		||||
			resp.ERROR(c, "系统已禁用邮箱注册!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = h.smtp.SendVerifyCode(data.Receiver, code)
 | 
			
		||||
	} else {
 | 
			
		||||
		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
 | 
			
		||||
		if !utils.Contains(h.App.SysConfig.RegisterWays, "mobile") {
 | 
			
		||||
			resp.ERROR(c, "系统已禁用手机号注册!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
@@ -89,5 +97,9 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
	if h.App.Debug {
 | 
			
		||||
		resp.SUCCESS(c, code)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										382
									
								
								api/handler/suno_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										382
									
								
								api/handler/suno_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,382 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/suno"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SunoHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	sunoService *suno.Service
 | 
			
		||||
	uploader    *oss.UploaderManager
 | 
			
		||||
	userService *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
 | 
			
		||||
	return &SunoHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
		sunoService: service,
 | 
			
		||||
		uploader:    uploader,
 | 
			
		||||
		userService: userService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *SunoHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		logger.Info("Invalid user ID")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.sunoService.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SunoHandler) Create(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Prompt       string `json:"prompt"`
 | 
			
		||||
		Instrumental bool   `json:"instrumental"`
 | 
			
		||||
		Lyrics       string `json:"lyrics"`
 | 
			
		||||
		Model        string `json:"model"`
 | 
			
		||||
		Tags         string `json:"tags"`
 | 
			
		||||
		Title        string `json:"title"`
 | 
			
		||||
		Type         int    `json:"type"`
 | 
			
		||||
		RefTaskId    string `json:"ref_task_id"`         // 续写的任务id
 | 
			
		||||
		ExtendSecs   int    `json:"extend_secs"`         // 续写秒数
 | 
			
		||||
		RefSongId    string `json:"ref_song_id"`         // 续写的歌曲id
 | 
			
		||||
		SongId       string `json:"song_id,omitempty"`   // 要拼接的歌曲id
 | 
			
		||||
		AudioURL     string `json:"audio_url,omitempty"` // 上传自己创作的歌曲
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 歌曲拼接
 | 
			
		||||
	if data.SongId != "" && data.Type == 3 {
 | 
			
		||||
		var song model.SunoJob
 | 
			
		||||
		if err := h.DB.Where("song_id = ?", data.SongId).First(&song).Error; err == nil {
 | 
			
		||||
			data.Instrumental = song.Instrumental
 | 
			
		||||
			data.Model = song.ModelName
 | 
			
		||||
			data.Tags = song.Tags
 | 
			
		||||
		}
 | 
			
		||||
		// 拼接歌词
 | 
			
		||||
		var refSong model.SunoJob
 | 
			
		||||
		if err := h.DB.Where("song_id = ?", data.RefSongId).First(&refSong).Error; err == nil {
 | 
			
		||||
			data.Prompt = fmt.Sprintf("%s\n%s", song.Prompt, refSong.Prompt)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 插入数据库
 | 
			
		||||
	job := model.SunoJob{
 | 
			
		||||
		UserId:       int(h.GetLoginUserId(c)),
 | 
			
		||||
		Prompt:       data.Prompt,
 | 
			
		||||
		Instrumental: data.Instrumental,
 | 
			
		||||
		ModelName:    data.Model,
 | 
			
		||||
		Tags:         data.Tags,
 | 
			
		||||
		Title:        data.Title,
 | 
			
		||||
		Type:         data.Type,
 | 
			
		||||
		RefSongId:    data.RefSongId,
 | 
			
		||||
		RefTaskId:    data.RefTaskId,
 | 
			
		||||
		ExtendSecs:   data.ExtendSecs,
 | 
			
		||||
		Power:        h.App.SysConfig.SunoPower,
 | 
			
		||||
		SongId:       utils.RandString(32),
 | 
			
		||||
	}
 | 
			
		||||
	if data.Lyrics != "" {
 | 
			
		||||
		job.Prompt = data.Lyrics
 | 
			
		||||
	}
 | 
			
		||||
	tx := h.DB.Create(&job)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, tx.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建任务
 | 
			
		||||
	h.sunoService.PushTask(types.SunoTask{
 | 
			
		||||
		Id:           job.Id,
 | 
			
		||||
		UserId:       job.UserId,
 | 
			
		||||
		Type:         job.Type,
 | 
			
		||||
		Title:        job.Title,
 | 
			
		||||
		RefTaskId:    data.RefTaskId,
 | 
			
		||||
		RefSongId:    data.RefSongId,
 | 
			
		||||
		ExtendSecs:   data.ExtendSecs,
 | 
			
		||||
		Prompt:       job.Prompt,
 | 
			
		||||
		Tags:         data.Tags,
 | 
			
		||||
		Model:        data.Model,
 | 
			
		||||
		Instrumental: data.Instrumental,
 | 
			
		||||
		SongId:       data.SongId,
 | 
			
		||||
		AudioURL:     data.AudioURL,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
		Type:      types.PowerConsume,
 | 
			
		||||
		Remark:    fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.sunoService.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SunoHandler) List(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
 | 
			
		||||
 | 
			
		||||
	// 统计总数
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.SunoJob{}).Count(&total)
 | 
			
		||||
 | 
			
		||||
	if page > 0 && pageSize > 0 {
 | 
			
		||||
		offset := (page - 1) * pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
	}
 | 
			
		||||
	var list []model.SunoJob
 | 
			
		||||
	err := session.Order("id desc").Find(&list).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 初始化续写关系
 | 
			
		||||
	songIds := make([]string, 0)
 | 
			
		||||
	for _, v := range list {
 | 
			
		||||
		if v.RefTaskId != "" {
 | 
			
		||||
			songIds = append(songIds, v.RefSongId)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var tasks []model.SunoJob
 | 
			
		||||
	h.DB.Where("song_id IN ?", songIds).Find(&tasks)
 | 
			
		||||
	songMap := make(map[string]model.SunoJob)
 | 
			
		||||
	for _, t := range tasks {
 | 
			
		||||
		songMap[t.SongId] = t
 | 
			
		||||
	}
 | 
			
		||||
	// 转换为 VO
 | 
			
		||||
	items := make([]vo.SunoJob, 0)
 | 
			
		||||
	for _, v := range list {
 | 
			
		||||
		var item vo.SunoJob
 | 
			
		||||
		err = utils.CopyObject(v, &item)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		item.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
		if s, ok := songMap[v.RefSongId]; ok {
 | 
			
		||||
			item.RefSong = map[string]interface{}{
 | 
			
		||||
				"id":    s.Id,
 | 
			
		||||
				"title": s.Title,
 | 
			
		||||
				"cover": s.CoverURL,
 | 
			
		||||
				"audio": s.AudioURL,
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		items = append(items, item)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SunoHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var job model.SunoJob
 | 
			
		||||
	err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 删除任务
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	if err := tx.Delete(&job).Error; err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果任务未完成,或者任务失败,则恢复用户算力
 | 
			
		||||
	if job.Progress != 100 {
 | 
			
		||||
		err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
			Type:   types.PowerRefund,
 | 
			
		||||
			Model:  job.ModelName,
 | 
			
		||||
			Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
 | 
			
		||||
	// 删除文件
 | 
			
		||||
	_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
 | 
			
		||||
	_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SunoHandler) Publish(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
	err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SunoHandler) Update(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id    int    `json:"id"`
 | 
			
		||||
		Title string `json:"title"`
 | 
			
		||||
		Cover string `json:"cover"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if data.Id == 0 || data.Title == "" || data.Cover == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var item model.SunoJob
 | 
			
		||||
	if err := h.DB.Where("id", data.Id).Where("user_id", userId).First(&item).Error; err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	item.Title = data.Title
 | 
			
		||||
	item.CoverURL = data.Cover
 | 
			
		||||
 | 
			
		||||
	if err := h.DB.Updates(&item).Error; err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Detail 歌曲详情
 | 
			
		||||
func (h *SunoHandler) Detail(c *gin.Context) {
 | 
			
		||||
	songId := c.Query("song_id")
 | 
			
		||||
	if songId == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var item model.SunoJob
 | 
			
		||||
	if err := h.DB.Where("song_id", songId).First(&item).Error; err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 读取用户信息
 | 
			
		||||
	var user model.User
 | 
			
		||||
	if err := h.DB.Where("id", item.UserId).First(&user).Error; err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var itemVo vo.SunoJob
 | 
			
		||||
	if err := utils.CopyObject(item, &itemVo); err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	itemVo.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
	itemVo.User = map[string]interface{}{
 | 
			
		||||
		"nickname": user.Nickname,
 | 
			
		||||
		"avatar":   user.Avatar,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, itemVo)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Play 增加歌曲播放次数
 | 
			
		||||
func (h *SunoHandler) Play(c *gin.Context) {
 | 
			
		||||
	songId := c.Query("song_id")
 | 
			
		||||
	if songId == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const genLyricTemplate = `
 | 
			
		||||
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
 | 
			
		||||
请以【%s】为主题创作一首歌曲,歌曲时间不要太短,3分钟左右,不要输出任何解释性的内容。
 | 
			
		||||
输出格式如下:
 | 
			
		||||
歌曲名称
 | 
			
		||||
第一节:
 | 
			
		||||
{{歌词内容}}
 | 
			
		||||
副歌:
 | 
			
		||||
{{歌词内容}}
 | 
			
		||||
 | 
			
		||||
第二节:
 | 
			
		||||
{{歌词内容}}
 | 
			
		||||
副歌:
 | 
			
		||||
{{歌词内容}}
 | 
			
		||||
 | 
			
		||||
尾声:
 | 
			
		||||
{{歌词内容}}
 | 
			
		||||
`
 | 
			
		||||
 | 
			
		||||
// Lyric 生成歌词
 | 
			
		||||
func (h *SunoHandler) Lyric(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Prompt string `json:"prompt"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, content)
 | 
			
		||||
}
 | 
			
		||||
@@ -3,15 +3,52 @@ package handler
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/payment"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TestHandler struct {
 | 
			
		||||
	db        *gorm.DB
 | 
			
		||||
	snowflake *service.Snowflake
 | 
			
		||||
	js        *payment.PayJS
 | 
			
		||||
	js        *payment.JPayService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
 | 
			
		||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.JPayService) *TestHandler {
 | 
			
		||||
	return &TestHandler{db: db, snowflake: snowflake, js: js}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *TestHandler) SseTest(c *gin.Context) {
 | 
			
		||||
	//c.Header("Content-Type", "text/event-stream")
 | 
			
		||||
	//c.Header("Cache-Control", "no-cache")
 | 
			
		||||
	//c.Header("Connection", "keep-alive")
 | 
			
		||||
	//
 | 
			
		||||
	//
 | 
			
		||||
	//// 模拟实时数据更新
 | 
			
		||||
	//for i := 0; i < 10; i++ {
 | 
			
		||||
	//	// 发送 SSE 数据
 | 
			
		||||
	//	_, err := fmt.Fprintf(c.Writer, "data: %v\n\n", data)
 | 
			
		||||
	//	if err != nil {
 | 
			
		||||
	//		return
 | 
			
		||||
	//	}
 | 
			
		||||
	//	c.Writer.Flush()            // 确保立即发送数据
 | 
			
		||||
	//	time.Sleep(1 * time.Second) // 每秒发送一次数据
 | 
			
		||||
	//}
 | 
			
		||||
	//c.Abort()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *TestHandler) PostTest(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
		UserId  uint   `json:"user_id"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 将参数存储在上下文中
 | 
			
		||||
	c.Set("data", data)
 | 
			
		||||
	c.Next()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,7 @@ package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
@@ -16,25 +17,33 @@ import (
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UploadHandler struct {
 | 
			
		||||
type NetHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
 | 
			
		||||
	return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
 | 
			
		||||
func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *NetHandler {
 | 
			
		||||
	return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *UploadHandler) Upload(c *gin.Context) {
 | 
			
		||||
func (h *NetHandler) Upload(c *gin.Context) {
 | 
			
		||||
	file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Info("upload file: ", file.Name)
 | 
			
		||||
	// cut the file name if it's too long
 | 
			
		||||
	if len(file.Name) > 100 {
 | 
			
		||||
		file.Name = file.Name[:90] + file.Ext
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	res := h.DB.Create(&model.File{
 | 
			
		||||
		UserId:    int(userId),
 | 
			
		||||
@@ -53,11 +62,24 @@ func (h *UploadHandler) Upload(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, file)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *UploadHandler) List(c *gin.Context) {
 | 
			
		||||
func (h *NetHandler) List(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Urls []string `json:"urls,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var items []model.File
 | 
			
		||||
	var files = make([]vo.File, 0)
 | 
			
		||||
	h.DB.Where("user_id = ?", userId).Find(&items)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	session = session.Where("user_id = ?", userId)
 | 
			
		||||
	if len(data.Urls) > 0 {
 | 
			
		||||
		session = session.Where("url IN ?", data.Urls)
 | 
			
		||||
	}
 | 
			
		||||
	session.Find(&items)
 | 
			
		||||
	if len(items) > 0 {
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			var file vo.File
 | 
			
		||||
@@ -75,7 +97,7 @@ func (h *UploadHandler) List(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove files
 | 
			
		||||
func (h *UploadHandler) Remove(c *gin.Context) {
 | 
			
		||||
func (h *NetHandler) Remove(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	var file model.File
 | 
			
		||||
@@ -99,3 +121,28 @@ func (h *UploadHandler) Remove(c *gin.Context) {
 | 
			
		||||
	_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *NetHandler) Download(c *gin.Context) {
 | 
			
		||||
	fileUrl := c.Query("url")
 | 
			
		||||
	// 使用http工具下载文件
 | 
			
		||||
	if fileUrl == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 使用http.Get下载文件
 | 
			
		||||
	r, err := http.Get(fileUrl)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer r.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if r.StatusCode != http.StatusOK {
 | 
			
		||||
		resp.ERROR(c, "error status:"+r.Status)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.Status(http.StatusOK)
 | 
			
		||||
	// 将下载的文件内容写入响应
 | 
			
		||||
	_, _ = io.Copy(c.Writer, r.Body)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,7 @@ import (
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@@ -32,6 +33,8 @@ type UserHandler struct {
 | 
			
		||||
	searcher       *xdb.Searcher
 | 
			
		||||
	redis          *redis.Client
 | 
			
		||||
	licenseService *service.LicenseService
 | 
			
		||||
	captcha        *service.CaptchaService
 | 
			
		||||
	userService    *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUserHandler(
 | 
			
		||||
@@ -39,12 +42,16 @@ func NewUserHandler(
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	searcher *xdb.Searcher,
 | 
			
		||||
	client *redis.Client,
 | 
			
		||||
	captcha *service.CaptchaService,
 | 
			
		||||
	userService *service.UserService,
 | 
			
		||||
	licenseService *service.LicenseService) *UserHandler {
 | 
			
		||||
	return &UserHandler{
 | 
			
		||||
		BaseHandler:    BaseHandler{DB: db, App: app},
 | 
			
		||||
		searcher:       searcher,
 | 
			
		||||
		redis:          client,
 | 
			
		||||
		captcha:        captcha,
 | 
			
		||||
		licenseService: licenseService,
 | 
			
		||||
		userService:    userService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -54,9 +61,14 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		RegWay     string `json:"reg_way"`
 | 
			
		||||
		Username   string `json:"username"`
 | 
			
		||||
		Mobile     string `json:"mobile"`
 | 
			
		||||
		Email      string `json:"email"`
 | 
			
		||||
		Password   string `json:"password"`
 | 
			
		||||
		Code       string `json:"code"`
 | 
			
		||||
		InviteCode string `json:"invite_code"`
 | 
			
		||||
		Key        string `json:"key,omitempty"`
 | 
			
		||||
		Dots       string `json:"dots,omitempty"`
 | 
			
		||||
		X          int    `json:"x,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -78,8 +90,15 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	// 检查验证码
 | 
			
		||||
	var key string
 | 
			
		||||
	if data.RegWay == "email" || data.RegWay == "mobile" {
 | 
			
		||||
		key = CodeStorePrefix + data.Username
 | 
			
		||||
	if data.RegWay == "email" {
 | 
			
		||||
		key = CodeStorePrefix + data.Email
 | 
			
		||||
		code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
		if err != nil || code != data.Code {
 | 
			
		||||
			resp.ERROR(c, "验证码错误")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	} else if data.RegWay == "mobile" {
 | 
			
		||||
		key = CodeStorePrefix + data.Mobile
 | 
			
		||||
		code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
		if err != nil || code != data.Code {
 | 
			
		||||
			resp.ERROR(c, "验证码错误")
 | 
			
		||||
@@ -97,9 +116,19 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// check if the username is exists
 | 
			
		||||
	// check if the username is existing
 | 
			
		||||
	var item model.User
 | 
			
		||||
	res := h.DB.Where("username = ?", data.Username).First(&item)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if data.Mobile != "" {
 | 
			
		||||
		session = session.Where("mobile = ?", data.Mobile)
 | 
			
		||||
		data.Username = data.Mobile
 | 
			
		||||
	} else if data.Email != "" {
 | 
			
		||||
		session = session.Where("email = ?", data.Email)
 | 
			
		||||
		data.Username = data.Email
 | 
			
		||||
	} else if data.Username != "" {
 | 
			
		||||
		session = session.Where("username = ?", data.Username)
 | 
			
		||||
	}
 | 
			
		||||
	session.First(&item)
 | 
			
		||||
	if item.Id > 0 {
 | 
			
		||||
		resp.ERROR(c, "该用户名已经被注册")
 | 
			
		||||
		return
 | 
			
		||||
@@ -108,8 +137,9 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
	salt := utils.RandString(8)
 | 
			
		||||
	user := model.User{
 | 
			
		||||
		Username:   data.Username,
 | 
			
		||||
		Mobile:     data.Mobile,
 | 
			
		||||
		Email:      data.Email,
 | 
			
		||||
		Password:   utils.GenPassword(data.Password, salt),
 | 
			
		||||
		Nickname:   fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
 | 
			
		||||
		Avatar:     "/images/avatar/user.png",
 | 
			
		||||
		Salt:       salt,
 | 
			
		||||
		Status:     true,
 | 
			
		||||
@@ -118,10 +148,19 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
		Power:      h.App.SysConfig.InitPower,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res = h.DB.Create(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "保存数据失败")
 | 
			
		||||
		logger.Error(res.Error)
 | 
			
		||||
	// 被邀请人也获得赠送算力
 | 
			
		||||
	if data.InviteCode != "" {
 | 
			
		||||
		user.Power += h.App.SysConfig.InvitePower
 | 
			
		||||
	}
 | 
			
		||||
	if h.licenseService.GetLicense().Configs.DeCopy {
 | 
			
		||||
		user.Nickname = fmt.Sprintf("用户@%d", utils.RandomNumber(6))
 | 
			
		||||
	} else {
 | 
			
		||||
		user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	if err := tx.Create(&user).Error; err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -130,35 +169,35 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
		// 增加邀请数量
 | 
			
		||||
		h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
 | 
			
		||||
		if h.App.SysConfig.InvitePower > 0 {
 | 
			
		||||
			h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
 | 
			
		||||
			// 记录邀请算力充值日志
 | 
			
		||||
			var inviter model.User
 | 
			
		||||
			h.DB.Where("id", inviteCode.UserId).First(&inviter)
 | 
			
		||||
			h.DB.Create(&model.PowerLog{
 | 
			
		||||
				UserId:    inviter.Id,
 | 
			
		||||
				Username:  inviter.Username,
 | 
			
		||||
				Type:      types.PowerInvite,
 | 
			
		||||
				Amount:    h.App.SysConfig.InvitePower,
 | 
			
		||||
				Balance:   inviter.Power,
 | 
			
		||||
				Mark:      types.PowerAdd,
 | 
			
		||||
				Model:     "",
 | 
			
		||||
				Remark:    fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
 | 
			
		||||
				CreatedAt: time.Now(),
 | 
			
		||||
			err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
 | 
			
		||||
				Type:   types.PowerInvite,
 | 
			
		||||
				Model:  "",
 | 
			
		||||
				Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
 | 
			
		||||
			})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				tx.Rollback()
 | 
			
		||||
				resp.ERROR(c, err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 添加邀请记录
 | 
			
		||||
		h.DB.Create(&model.InviteLog{
 | 
			
		||||
		err := tx.Create(&model.InviteLog{
 | 
			
		||||
			InviterId:  inviteCode.UserId,
 | 
			
		||||
			UserId:     user.Id,
 | 
			
		||||
			Username:   user.Username,
 | 
			
		||||
			InviteCode: inviteCode.Code,
 | 
			
		||||
			Remark:     fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
 | 
			
		||||
		})
 | 
			
		||||
		}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
 | 
			
		||||
	_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
 | 
			
		||||
 | 
			
		||||
	// 自动登录创建 token
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"user_id": user.Id,
 | 
			
		||||
@@ -175,7 +214,7 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
		resp.ERROR(c, "error with save token: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, tokenString)
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Login 用户登录
 | 
			
		||||
@@ -183,11 +222,28 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
		Key      string `json:"key,omitempty"`
 | 
			
		||||
		Dots     string `json:"dots,omitempty"`
 | 
			
		||||
		X        int    `json:"x,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if h.App.SysConfig.EnabledVerify {
 | 
			
		||||
		var check bool
 | 
			
		||||
		if data.X != 0 {
 | 
			
		||||
			check = h.captcha.SlideCheck(data)
 | 
			
		||||
		} else {
 | 
			
		||||
			check = h.captcha.Check(data)
 | 
			
		||||
		}
 | 
			
		||||
		if !check {
 | 
			
		||||
			resp.ERROR(c, "请先完人机验证")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.DB.Where("username = ?", data.Username).First(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
@@ -234,7 +290,7 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
		resp.ERROR(c, "error with save token: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, tokenString)
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"token": tokenString, "user_id": user.Id, "username": user.Username})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Logout 注 销
 | 
			
		||||
@@ -246,21 +302,176 @@ func (h *UserHandler) Logout(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CLogin 第三方登录请求二维码
 | 
			
		||||
func (h *UserHandler) CLogin(c *gin.Context) {
 | 
			
		||||
	returnURL := h.GetTrim(c, "return_url")
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/api/clogin/request", h.App.Config.ApiConfig.ApiURL)
 | 
			
		||||
	r, err := req.C().R().SetBody(gin.H{"login_type": "wx", "return_url": returnURL}).
 | 
			
		||||
		SetHeader("AppId", h.App.Config.ApiConfig.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		resp.ERROR(c, "error with login http status: "+r.Status)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		resp.ERROR(c, "error with http response: "+res.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, res.Data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CLoginCallback 第三方登录回调
 | 
			
		||||
func (h *UserHandler) CLoginCallback(c *gin.Context) {
 | 
			
		||||
	loginType := c.Query("login_type")
 | 
			
		||||
	code := c.Query("code")
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	action := c.Query("action")
 | 
			
		||||
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
 | 
			
		||||
	r, err := req.C().R().SetBody(gin.H{"login_type": loginType, "code": code}).
 | 
			
		||||
		SetHeader("AppId", h.App.Config.ApiConfig.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.App.Config.ApiConfig.Token)).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		resp.ERROR(c, "error with login http status: "+r.Status)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		resp.ERROR(c, "error with http response: "+res.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// login successfully
 | 
			
		||||
	data := res.Data.(map[string]interface{})
 | 
			
		||||
	var user model.User
 | 
			
		||||
	if action == "bind" && userId > 0 {
 | 
			
		||||
		err = h.DB.Where("openid", data["openid"]).First(&user).Error
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = h.DB.Where("id", userId).First(&user).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "绑定用户不存在")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "更新用户信息失败,"+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		resp.SUCCESS(c, gin.H{"token": ""})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := gin.H{}
 | 
			
		||||
	tx := h.DB.Where("openid", data["openid"]).First(&user)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		// create new user
 | 
			
		||||
		var totalUser int64
 | 
			
		||||
		h.DB.Model(&model.User{}).Count(&totalUser)
 | 
			
		||||
		if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
 | 
			
		||||
			resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		salt := utils.RandString(8)
 | 
			
		||||
		password := fmt.Sprintf("%d", utils.RandomNumber(8))
 | 
			
		||||
		user = model.User{
 | 
			
		||||
			Username:   fmt.Sprintf("%s@%d", loginType, utils.RandomNumber(10)),
 | 
			
		||||
			Password:   utils.GenPassword(password, salt),
 | 
			
		||||
			Avatar:     fmt.Sprintf("%s", data["avatar"]),
 | 
			
		||||
			Salt:       salt,
 | 
			
		||||
			Status:     true,
 | 
			
		||||
			ChatRoles:  utils.JsonEncode([]string{"gpt"}),               // 默认只订阅通用助手角色
 | 
			
		||||
			ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
 | 
			
		||||
			Power:      h.App.SysConfig.InitPower,
 | 
			
		||||
			OpenId:     fmt.Sprintf("%s", data["openid"]),
 | 
			
		||||
			Nickname:   fmt.Sprintf("%s", data["nickname"]),
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tx = h.DB.Create(&user)
 | 
			
		||||
		if tx.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "保存数据失败")
 | 
			
		||||
			logger.Error(tx.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		session["username"] = user.Username
 | 
			
		||||
		session["password"] = password
 | 
			
		||||
	} else { // login directly
 | 
			
		||||
		// 更新最后登录时间和IP
 | 
			
		||||
		user.LastLoginIp = c.ClientIP()
 | 
			
		||||
		user.LastLoginAt = time.Now().Unix()
 | 
			
		||||
		h.DB.Model(&user).Updates(user)
 | 
			
		||||
 | 
			
		||||
		h.DB.Create(&model.UserLoginLog{
 | 
			
		||||
			UserId:       user.Id,
 | 
			
		||||
			Username:     user.Username,
 | 
			
		||||
			LoginIp:      c.ClientIP(),
 | 
			
		||||
			LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建 token
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"user_id": user.Id,
 | 
			
		||||
		"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
 | 
			
		||||
	})
 | 
			
		||||
	tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "Failed to generate token, "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 保存到 redis
 | 
			
		||||
	key := fmt.Sprintf("users/%d", user.Id)
 | 
			
		||||
	if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with save token: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	session["token"] = tokenString
 | 
			
		||||
	resp.SUCCESS(c, session)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Session 获取/验证会话
 | 
			
		||||
func (h *UserHandler) Session(c *gin.Context) {
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		var userVo vo.User
 | 
			
		||||
		err := utils.CopyObject(user, &userVo)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c)
 | 
			
		||||
		}
 | 
			
		||||
		userVo.Id = user.Id
 | 
			
		||||
		resp.SUCCESS(c, userVo)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var userVo vo.User
 | 
			
		||||
	err = utils.CopyObject(user, &userVo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 用户 VIP 到期
 | 
			
		||||
	if user.ExpiredTime > 0 && user.ExpiredTime < time.Now().Unix() {
 | 
			
		||||
		h.DB.Model(&user).UpdateColumn("vip", false)
 | 
			
		||||
	}
 | 
			
		||||
	userVo.Id = user.Id
 | 
			
		||||
	resp.SUCCESS(c, userVo)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type userProfile struct {
 | 
			
		||||
@@ -347,20 +558,21 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newPass := utils.GenPassword(data.Password, user.Salt)
 | 
			
		||||
	res := h.DB.Model(&user).UpdateColumn("password", newPass)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update database:", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
	err = h.DB.Model(&user).UpdateColumn("password", newPass).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResetPass 重置密码
 | 
			
		||||
// ResetPass 找回密码
 | 
			
		||||
func (h *UserHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Type     string `json:"type"`     // 验证类别:mobile, email
 | 
			
		||||
		Mobile   string `json:"mobile"`   // 手机号
 | 
			
		||||
		Email    string `json:"email"`    // 邮箱地址
 | 
			
		||||
		Code     string `json:"code"`     // 验证码
 | 
			
		||||
		Password string `json:"password"` // 新密码
 | 
			
		||||
	}
 | 
			
		||||
@@ -369,37 +581,47 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	var key string
 | 
			
		||||
	if data.Type == "email" {
 | 
			
		||||
		session = session.Where("email", data.Email)
 | 
			
		||||
		key = CodeStorePrefix + data.Email
 | 
			
		||||
	} else if data.Type == "mobile" {
 | 
			
		||||
		session = session.Where("mobile", data.Email)
 | 
			
		||||
		key = CodeStorePrefix + data.Mobile
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.ERROR(c, "验证类别错误")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.DB.Where("username", data.Username).First(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
	err := session.First(&user).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "用户不存在!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查验证码
 | 
			
		||||
	key := CodeStorePrefix + data.Username
 | 
			
		||||
	code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
	if err != nil || code != data.Code {
 | 
			
		||||
		resp.ERROR(c, "短信验证码错误")
 | 
			
		||||
		resp.ERROR(c, "验证码错误")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	password := utils.GenPassword(data.Password, user.Salt)
 | 
			
		||||
	user.Password = password
 | 
			
		||||
	res = h.DB.Updates(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c)
 | 
			
		||||
	err = h.DB.Model(&user).UpdateColumn("password", password).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
	} else {
 | 
			
		||||
		h.redis.Del(c, key)
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BindUsername 重置账号
 | 
			
		||||
func (h *UserHandler) BindUsername(c *gin.Context) {
 | 
			
		||||
// BindMobile 绑定手机号
 | 
			
		||||
func (h *UserHandler) BindMobile(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Code     string `json:"code"`
 | 
			
		||||
		Mobile string `json:"mobile"`
 | 
			
		||||
		Code   string `json:"code"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -407,7 +629,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查验证码
 | 
			
		||||
	key := CodeStorePrefix + data.Username
 | 
			
		||||
	key := CodeStorePrefix + data.Mobile
 | 
			
		||||
	code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
	if err != nil || code != data.Code {
 | 
			
		||||
		resp.ERROR(c, "验证码错误")
 | 
			
		||||
@@ -416,22 +638,56 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	// 检查手机号是否被其他账号绑定
 | 
			
		||||
	var item model.User
 | 
			
		||||
	res := h.DB.Where("username = ?", data.Username).First(&item)
 | 
			
		||||
	res := h.DB.Where("mobile", data.Mobile).First(&item)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		resp.ERROR(c, "该账号已经被其他账号绑定")
 | 
			
		||||
		resp.ERROR(c, "该手机号已经绑定了其他账号,请更换手机号")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
 | 
			
		||||
	err = h.DB.Model(&item).Where("id", userId).UpdateColumn("mobile", data.Mobile).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res = h.DB.Model(&user).UpdateColumn("username", data.Username)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error(res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ = h.redis.Del(c, key) // 删除短信验证码
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BindEmail 绑定邮箱
 | 
			
		||||
func (h *UserHandler) BindEmail(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Email string `json:"email"`
 | 
			
		||||
		Code  string `json:"code"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查验证码
 | 
			
		||||
	key := CodeStorePrefix + data.Email
 | 
			
		||||
	code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
	if err != nil || code != data.Code {
 | 
			
		||||
		resp.ERROR(c, "验证码错误")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查手机号是否被其他账号绑定
 | 
			
		||||
	var item model.User
 | 
			
		||||
	res := h.DB.Where("email", data.Email).First(&item)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		resp.ERROR(c, "该邮箱地址已经绑定了其他账号,请更邮箱地址")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
 | 
			
		||||
	err = h.DB.Model(&item).Where("id", userId).UpdateColumn("email", data.Email).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										233
									
								
								api/handler/video_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								api/handler/video_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,233 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/video"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type VideoHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	videoService *video.Service
 | 
			
		||||
	uploader     *oss.UploaderManager
 | 
			
		||||
	userService  *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
 | 
			
		||||
	return &VideoHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
		videoService: service,
 | 
			
		||||
		uploader:     uploader,
 | 
			
		||||
		userService:  userService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *VideoHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		logger.Info("Invalid user ID")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.videoService.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Prompt        string `json:"prompt"`
 | 
			
		||||
		FirstFrameImg string `json:"first_frame_img,omitempty"`
 | 
			
		||||
		EndFrameImg   string `json:"end_frame_img,omitempty"`
 | 
			
		||||
		ExpandPrompt  bool   `json:"expand_prompt,omitempty"`
 | 
			
		||||
		Loop          bool   `json:"loop,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if data.Prompt == "" {
 | 
			
		||||
		resp.ERROR(c, "prompt is needed")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := int(h.GetLoginUserId(c))
 | 
			
		||||
	params := types.VideoParams{
 | 
			
		||||
		PromptOptimize: data.ExpandPrompt,
 | 
			
		||||
		Loop:           data.Loop,
 | 
			
		||||
		StartImgURL:    data.FirstFrameImg,
 | 
			
		||||
		EndImgURL:      data.EndFrameImg,
 | 
			
		||||
	}
 | 
			
		||||
	// 插入数据库
 | 
			
		||||
	job := model.VideoJob{
 | 
			
		||||
		UserId: userId,
 | 
			
		||||
		Type:   types.VideoLuma,
 | 
			
		||||
		Prompt: data.Prompt,
 | 
			
		||||
		Power:  h.App.SysConfig.LumaPower,
 | 
			
		||||
		Params: utils.JsonEncode(params),
 | 
			
		||||
	}
 | 
			
		||||
	tx := h.DB.Create(&job)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, tx.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建任务
 | 
			
		||||
	h.videoService.PushTask(types.VideoTask{
 | 
			
		||||
		Id:     job.Id,
 | 
			
		||||
		UserId: userId,
 | 
			
		||||
		Type:   types.VideoLuma,
 | 
			
		||||
		Prompt: data.Prompt,
 | 
			
		||||
		Params: params,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerConsume,
 | 
			
		||||
		Model:  "luma",
 | 
			
		||||
		Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.videoService.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) List(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	t := c.Query("type")
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	all := h.GetBool(c, "all")
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
 | 
			
		||||
	if t != "" {
 | 
			
		||||
		session = session.Where("type", t)
 | 
			
		||||
	}
 | 
			
		||||
	if all {
 | 
			
		||||
		session = session.Where("publish", 0).Where("progress", 100)
 | 
			
		||||
	} else {
 | 
			
		||||
		session = session.Where("user_id", h.GetLoginUserId(c))
 | 
			
		||||
	}
 | 
			
		||||
	// 统计总数
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.VideoJob{}).Count(&total)
 | 
			
		||||
 | 
			
		||||
	if page > 0 && pageSize > 0 {
 | 
			
		||||
		offset := (page - 1) * pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
	}
 | 
			
		||||
	var list []model.VideoJob
 | 
			
		||||
	err := session.Order("id desc").Find(&list).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 转换为 VO
 | 
			
		||||
	items := make([]vo.VideoJob, 0)
 | 
			
		||||
	for _, v := range list {
 | 
			
		||||
		var item vo.VideoJob
 | 
			
		||||
		err = utils.CopyObject(v, &item)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		item.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
		items = append(items, item)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var job model.VideoJob
 | 
			
		||||
	err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 删除任务
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	if err := tx.Delete(&job).Error; err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果任务未完成,或者任务失败,则恢复用户算力
 | 
			
		||||
	if job.Progress != 100 {
 | 
			
		||||
		err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
 | 
			
		||||
			Type:   types.PowerRefund,
 | 
			
		||||
			Model:  "luma",
 | 
			
		||||
			Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
 | 
			
		||||
	// 删除文件
 | 
			
		||||
	_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
 | 
			
		||||
	_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) Publish(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
	var job model.VideoJob
 | 
			
		||||
	err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = h.DB.Model(&job).UpdateColumn("publish", publish).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										153
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										153
									
								
								api/main.go
									
									
									
									
									
								
							@@ -23,7 +23,8 @@ import (
 | 
			
		||||
	"geekai/service/payment"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/service/sms"
 | 
			
		||||
	"geekai/service/wx"
 | 
			
		||||
	"geekai/service/suno"
 | 
			
		||||
	"geekai/service/video"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
@@ -128,9 +129,9 @@ func main() {
 | 
			
		||||
		fx.Provide(handler.NewChatRoleHandler),
 | 
			
		||||
		fx.Provide(handler.NewUserHandler),
 | 
			
		||||
		fx.Provide(chatimpl.NewChatHandler),
 | 
			
		||||
		fx.Provide(handler.NewUploadHandler),
 | 
			
		||||
		fx.Provide(handler.NewNetHandler),
 | 
			
		||||
		fx.Provide(handler.NewSmsHandler),
 | 
			
		||||
		fx.Provide(handler.NewRewardHandler),
 | 
			
		||||
		fx.Provide(handler.NewRedeemHandler),
 | 
			
		||||
		fx.Provide(handler.NewCaptchaHandler),
 | 
			
		||||
		fx.Provide(handler.NewMidJourneyHandler),
 | 
			
		||||
		fx.Provide(handler.NewChatModelHandler),
 | 
			
		||||
@@ -146,7 +147,7 @@ func main() {
 | 
			
		||||
		fx.Provide(admin.NewApiKeyHandler),
 | 
			
		||||
		fx.Provide(admin.NewUserHandler),
 | 
			
		||||
		fx.Provide(admin.NewChatRoleHandler),
 | 
			
		||||
		fx.Provide(admin.NewRewardHandler),
 | 
			
		||||
		fx.Provide(admin.NewRedeemHandler),
 | 
			
		||||
		fx.Provide(admin.NewDashboardHandler),
 | 
			
		||||
		fx.Provide(admin.NewChatModelHandler),
 | 
			
		||||
		fx.Provide(admin.NewProductHandler),
 | 
			
		||||
@@ -160,13 +161,12 @@ func main() {
 | 
			
		||||
			return service.NewCaptchaService(config.ApiConfig)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(oss.NewUploaderManager),
 | 
			
		||||
		fx.Provide(mj.NewService),
 | 
			
		||||
		fx.Provide(dalle.NewService),
 | 
			
		||||
		fx.Invoke(func(service *dalle.Service) {
 | 
			
		||||
			service.Run()
 | 
			
		||||
			service.CheckTaskNotify()
 | 
			
		||||
			service.DownloadImages()
 | 
			
		||||
			service.CheckTaskStatus()
 | 
			
		||||
		fx.Invoke(func(s *dalle.Service) {
 | 
			
		||||
			s.Run()
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
			s.DownloadImages()
 | 
			
		||||
			s.CheckTaskStatus()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 邮件服务
 | 
			
		||||
@@ -177,41 +177,43 @@ func main() {
 | 
			
		||||
			licenseService.SyncLicense()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 微信机器人服务
 | 
			
		||||
		fx.Provide(wx.NewWeChatBot),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
 | 
			
		||||
			if config.WeChatBot {
 | 
			
		||||
				err := bot.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error("微信登录失败:", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// MidJourney service pool
 | 
			
		||||
		fx.Provide(mj.NewServicePool),
 | 
			
		||||
		fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
 | 
			
		||||
			pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
 | 
			
		||||
			if pool.HasAvailableService() {
 | 
			
		||||
				pool.DownloadImages()
 | 
			
		||||
				pool.CheckTaskNotify()
 | 
			
		||||
				pool.SyncTaskProgress()
 | 
			
		||||
			}
 | 
			
		||||
		fx.Provide(mj.NewService),
 | 
			
		||||
		fx.Provide(mj.NewClient),
 | 
			
		||||
		fx.Invoke(func(s *mj.Service) {
 | 
			
		||||
			s.Run()
 | 
			
		||||
			s.SyncTaskProgress()
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
			s.DownloadImages()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// Stable Diffusion 机器人
 | 
			
		||||
		fx.Provide(sd.NewServicePool),
 | 
			
		||||
		fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
 | 
			
		||||
			pool.InitServices(config.SdConfigs)
 | 
			
		||||
			if pool.HasAvailableService() {
 | 
			
		||||
				pool.CheckTaskNotify()
 | 
			
		||||
				pool.CheckTaskStatus()
 | 
			
		||||
			}
 | 
			
		||||
		fx.Provide(sd.NewService),
 | 
			
		||||
		fx.Invoke(func(s *sd.Service, config *types.AppConfig) {
 | 
			
		||||
			s.Run()
 | 
			
		||||
			s.CheckTaskStatus()
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(suno.NewService),
 | 
			
		||||
		fx.Invoke(func(s *suno.Service) {
 | 
			
		||||
			s.Run()
 | 
			
		||||
			s.SyncTaskProgress()
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
			s.DownloadFiles()
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(video.NewService),
 | 
			
		||||
		fx.Invoke(func(s *video.Service) {
 | 
			
		||||
			s.Run()
 | 
			
		||||
			s.SyncTaskProgress()
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
			s.DownloadFiles()
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(service.NewUserService),
 | 
			
		||||
		fx.Provide(payment.NewAlipayService),
 | 
			
		||||
		fx.Provide(payment.NewHuPiPay),
 | 
			
		||||
		fx.Provide(payment.NewPayJS),
 | 
			
		||||
		fx.Provide(payment.NewJPayService),
 | 
			
		||||
		fx.Provide(payment.NewWechatService),
 | 
			
		||||
		fx.Provide(service.NewSnowflake),
 | 
			
		||||
		fx.Provide(service.NewXXLJobExecutor),
 | 
			
		||||
		fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
 | 
			
		||||
@@ -237,8 +239,11 @@ func main() {
 | 
			
		||||
			group.GET("profile", h.Profile)
 | 
			
		||||
			group.POST("profile/update", h.ProfileUpdate)
 | 
			
		||||
			group.POST("password", h.UpdatePass)
 | 
			
		||||
			group.POST("bind/username", h.BindUsername)
 | 
			
		||||
			group.POST("bind/mobile", h.BindMobile)
 | 
			
		||||
			group.POST("bind/email", h.BindEmail)
 | 
			
		||||
			group.POST("resetPass", h.ResetPass)
 | 
			
		||||
			group.GET("clogin", h.CLogin)
 | 
			
		||||
			group.GET("clogin/callback", h.CLoginCallback)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/chat/")
 | 
			
		||||
@@ -252,10 +257,11 @@ func main() {
 | 
			
		||||
			group.POST("tokens", h.Tokens)
 | 
			
		||||
			group.GET("stop", h.StopGenerate)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
 | 
			
		||||
			s.Engine.POST("/api/upload", h.Upload)
 | 
			
		||||
			s.Engine.GET("/api/upload/list", h.List)
 | 
			
		||||
			s.Engine.POST("/api/upload/list", h.List)
 | 
			
		||||
			s.Engine.GET("/api/upload/remove", h.Remove)
 | 
			
		||||
			s.Engine.GET("/api/download", h.Download)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/sms/")
 | 
			
		||||
@@ -268,8 +274,8 @@ func main() {
 | 
			
		||||
			group.GET("slide/get", h.SlideGet)
 | 
			
		||||
			group.POST("slide/check", h.SlideCheck)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/reward/")
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.RedeemHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/redeem/")
 | 
			
		||||
			group.POST("verify", h.Verify)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
			
		||||
@@ -280,8 +286,8 @@ func main() {
 | 
			
		||||
			group.POST("variation", h.Variation)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.GET("imgWall", h.ImgWall)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
			group.POST("publish", h.Publish)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/sd")
 | 
			
		||||
@@ -289,8 +295,8 @@ func main() {
 | 
			
		||||
			group.POST("image", h.Image)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.GET("imgWall", h.ImgWall)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
			group.POST("publish", h.Publish)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/config/")
 | 
			
		||||
@@ -305,8 +311,6 @@ func main() {
 | 
			
		||||
			group.GET("config/get", h.Get)
 | 
			
		||||
			group.POST("active", h.Active)
 | 
			
		||||
			group.GET("config/get/license", h.GetLicense)
 | 
			
		||||
			group.GET("config/get/app", h.GetAppConfig)
 | 
			
		||||
			group.POST("config/update/draw", h.SaveDrawingConfig)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/")
 | 
			
		||||
@@ -342,9 +346,11 @@ func main() {
 | 
			
		||||
			group.POST("set", h.Set)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/reward/")
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.RedeemHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/redeem/")
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.POST("create", h.Create)
 | 
			
		||||
			group.POST("set", h.Set)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
 | 
			
		||||
@@ -367,12 +373,12 @@ func main() {
 | 
			
		||||
			group := s.Engine.Group("/api/payment/")
 | 
			
		||||
			group.GET("doPay", h.DoPay)
 | 
			
		||||
			group.GET("payWays", h.GetPayWays)
 | 
			
		||||
			group.POST("query", h.OrderQuery)
 | 
			
		||||
			group.POST("qrcode", h.PayQrcode)
 | 
			
		||||
			group.POST("mobile", h.Mobile)
 | 
			
		||||
			group.POST("alipay/notify", h.AlipayNotify)
 | 
			
		||||
			group.POST("hupipay/notify", h.HuPiPayNotify)
 | 
			
		||||
			group.POST("payjs/notify", h.PayJsNotify)
 | 
			
		||||
			group.POST("wechat/notify", h.WechatPayNotify)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/product/")
 | 
			
		||||
@@ -386,10 +392,12 @@ func main() {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/order/")
 | 
			
		||||
			group.POST("list", h.List)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("clear", h.Clear)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/order/")
 | 
			
		||||
			group.POST("list", h.List)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.GET("query", h.Query)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/product/")
 | 
			
		||||
@@ -400,7 +408,7 @@ func main() {
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/invite/")
 | 
			
		||||
			group.GET("code", h.Code)
 | 
			
		||||
			group.POST("list", h.List)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.GET("hits", h.Hits)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
@@ -414,13 +422,6 @@ func main() {
 | 
			
		||||
			group.GET("token", h.GenToken)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 验证码
 | 
			
		||||
		fx.Provide(admin.NewCaptchaHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/login/")
 | 
			
		||||
			group.GET("captcha", h.GetCaptcha)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(admin.NewUploadHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
 | 
			
		||||
			s.Engine.POST("/api/admin/upload", h.Upload)
 | 
			
		||||
@@ -432,6 +433,7 @@ func main() {
 | 
			
		||||
			group.POST("weibo", h.WeiBo)
 | 
			
		||||
			group.POST("zaobao", h.ZaoBao)
 | 
			
		||||
			group.POST("dalle3", h.Dall3)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/chat/")
 | 
			
		||||
@@ -475,8 +477,35 @@ func main() {
 | 
			
		||||
			group.POST("image", h.Image)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.GET("imgWall", h.ImgWall)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
			group.POST("publish", h.Publish)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewSunoHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/suno")
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.POST("create", h.Create)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("publish", h.Publish)
 | 
			
		||||
			group.POST("update", h.Update)
 | 
			
		||||
			group.GET("detail", h.Detail)
 | 
			
		||||
			group.GET("play", h.Play)
 | 
			
		||||
			group.POST("lyric", h.Lyric)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewVideoHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/video")
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.POST("luma/create", h.LumaCreate)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewTestHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/test")
 | 
			
		||||
			group.Any("sse", h.PostTest, h.SseTest)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
 | 
			
		||||
			go func() {
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,6 @@ import (
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
@@ -36,9 +35,10 @@ type Service struct {
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	notifyQueue   *store.RedisQueue
 | 
			
		||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
	userService   *service.UserService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
 | 
			
		||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
			
		||||
		db:            db,
 | 
			
		||||
@@ -46,6 +46,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
 | 
			
		||||
		notifyQueue:   store.NewRedisQueue("DallE_Notify_Queue", redisCli),
 | 
			
		||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		userService:   userService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -70,10 +71,10 @@ func (s *Service) Run() {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("error with image task: %v", err)
 | 
			
		||||
				s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"progress": -1,
 | 
			
		||||
					"progress": service.FailTaskProgress,
 | 
			
		||||
					"err_msg":  err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
@@ -109,13 +110,12 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
			
		||||
	logger.Debugf("绘画参数:%+v", task)
 | 
			
		||||
	prompt := task.Prompt
 | 
			
		||||
	// translate prompt
 | 
			
		||||
	if utils.HasChinese(task.Prompt) {
 | 
			
		||||
		content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", fmt.Errorf("error with translate prompt: %v", err)
 | 
			
		||||
	if utils.HasChinese(prompt) {
 | 
			
		||||
		content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			prompt = content
 | 
			
		||||
			logger.Debugf("重写后提示词:%s", prompt)
 | 
			
		||||
		}
 | 
			
		||||
		prompt = content
 | 
			
		||||
		logger.Debugf("重写后提示词:%s", prompt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
@@ -124,14 +124,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
			
		||||
		return "", errors.New("insufficient of power")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 扣减算力
 | 
			
		||||
	err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{
 | 
			
		||||
		Type:   types.PowerConsume,
 | 
			
		||||
		Model:  "dall-e-3",
 | 
			
		||||
		Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with decrease power: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// get image generation API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	tx := s.db.Where("platform", types.OpenAI.Value).
 | 
			
		||||
		Where("type", "img").
 | 
			
		||||
	err = s.db.Where("type", "dalle").
 | 
			
		||||
		Where("enabled", true).
 | 
			
		||||
		Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
 | 
			
		||||
		Order("last_used_at ASC").First(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("no available DALL-E api key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res imgRes
 | 
			
		||||
@@ -139,39 +148,42 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
			
		||||
	if len(apiKey.ProxyURL) > 5 {
 | 
			
		||||
		s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
 | 
			
		||||
	}
 | 
			
		||||
	logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
 | 
			
		||||
	reqBody := imgReq{
 | 
			
		||||
		Model:   "dall-e-3",
 | 
			
		||||
		Prompt:  prompt,
 | 
			
		||||
		N:       1,
 | 
			
		||||
		Size:    task.Size,
 | 
			
		||||
		Style:   task.Style,
 | 
			
		||||
		Quality: task.Quality,
 | 
			
		||||
	}
 | 
			
		||||
	logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
 | 
			
		||||
	r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(imgReq{
 | 
			
		||||
			Model:   "dall-e-3",
 | 
			
		||||
			Prompt:  prompt,
 | 
			
		||||
			N:       1,
 | 
			
		||||
			Size:    task.Size,
 | 
			
		||||
			Style:   task.Style,
 | 
			
		||||
			Quality: task.Quality,
 | 
			
		||||
		}).
 | 
			
		||||
		SetBody(reqBody).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		SetSuccessResult(&res).Post(apiKey.ApiURL)
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with send request: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("error with send request: %v", errRes.Error)
 | 
			
		||||
		return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error)
 | 
			
		||||
	}
 | 
			
		||||
	// update the api key last use time
 | 
			
		||||
	s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
	// update task progress
 | 
			
		||||
	tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
	err = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
		"progress": 100,
 | 
			
		||||
		"org_url":  res.Data[0].Url,
 | 
			
		||||
		"prompt":   prompt,
 | 
			
		||||
	})
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return "", fmt.Errorf("err with update database: %v", tx.Error)
 | 
			
		||||
	}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("err with update database: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
 | 
			
		||||
	s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
 | 
			
		||||
	var content string
 | 
			
		||||
	if sync {
 | 
			
		||||
		imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
 | 
			
		||||
@@ -181,25 +193,6 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
			
		||||
		content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户算力
 | 
			
		||||
	tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		var u model.User
 | 
			
		||||
		s.db.Where("id", user.Id).First(&u)
 | 
			
		||||
		s.db.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    task.Power,
 | 
			
		||||
			Balance:   u.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "dall-e-3",
 | 
			
		||||
			Remark:    fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return content, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -207,7 +200,7 @@ func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running DALL-E task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message sd.NotifyMessage
 | 
			
		||||
			var message service.NotifyMessage
 | 
			
		||||
			err := s.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
@@ -224,6 +217,30 @@ func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running DALL-E task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.DallJob
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 超时的任务标记为失败
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
 | 
			
		||||
					job.Progress = service.FailTaskProgress
 | 
			
		||||
					job.ErrMsg = "任务超时"
 | 
			
		||||
					s.db.Updates(&job)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) DownloadImages() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.DallJob
 | 
			
		||||
@@ -257,7 +274,7 @@ func (s *Service) DownloadImages() {
 | 
			
		||||
 | 
			
		||||
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
 | 
			
		||||
	// sava image
 | 
			
		||||
	imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
 | 
			
		||||
	imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
@@ -267,47 +284,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
 | 
			
		||||
	s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
 | 
			
		||||
	return imgURL, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (s *Service) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.DallJob
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 5 分钟还没完成的任务直接删除
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
 | 
			
		||||
					s.db.Delete(&job)
 | 
			
		||||
					var user model.User
 | 
			
		||||
					s.db.Where("id = ?", job.UserId).First(&user)
 | 
			
		||||
					// 退回绘图次数
 | 
			
		||||
					res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
 | 
			
		||||
					if res.Error == nil && res.RowsAffected > 0 {
 | 
			
		||||
						s.db.Create(&model.PowerLog{
 | 
			
		||||
							UserId:    user.Id,
 | 
			
		||||
							Username:  user.Username,
 | 
			
		||||
							Type:      types.PowerConsume,
 | 
			
		||||
							Amount:    job.Power,
 | 
			
		||||
							Balance:   user.Power + job.Power,
 | 
			
		||||
							Mark:      types.PowerAdd,
 | 
			
		||||
							Model:     "dall-e-3",
 | 
			
		||||
							Remark:    fmt.Sprintf("任务失败,退回算力。任务ID:%d", job.Id),
 | 
			
		||||
							CreatedAt: time.Now(),
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -91,7 +91,7 @@ func (s *LicenseService) SyncLicense() {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				retryCounter++
 | 
			
		||||
				if retryCounter < 5 {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
					logger.Warn(err)
 | 
			
		||||
				}
 | 
			
		||||
				s.license.IsActive = false
 | 
			
		||||
			} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,15 +7,28 @@ package mj
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import "geekai/core/types"
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"io"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
type Client interface {
 | 
			
		||||
	Imagine(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Blend(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	SwapFace(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Upscale(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Variation(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	QueryTask(taskId string) (QueryRes, error)
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Client MidJourney client
 | 
			
		||||
type Client struct {
 | 
			
		||||
	client         *req.Client
 | 
			
		||||
	licenseService *service.LicenseService
 | 
			
		||||
	db             *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageReq struct {
 | 
			
		||||
@@ -33,13 +46,8 @@ type ImageRes struct {
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
	Result string `json:"result"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
	Result  string `json:"result"`
 | 
			
		||||
	Channel string `json:"channel,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type QueryRes struct {
 | 
			
		||||
@@ -66,3 +74,177 @@ type QueryRes struct {
 | 
			
		||||
	Status     string `json:"status"`
 | 
			
		||||
	SubmitTime int    `json:"submitTime"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
 | 
			
		||||
	return &Client{
 | 
			
		||||
		client:         req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
 | 
			
		||||
		licenseService: licenseService,
 | 
			
		||||
		db:             db,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
 | 
			
		||||
	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
 | 
			
		||||
	if task.NegPrompt != "" {
 | 
			
		||||
		prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
 | 
			
		||||
	}
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Prompt:      prompt,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blend 融图
 | 
			
		||||
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Dimensions:  "SQUARE",
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		for _, imgURL := range task.ImgArr {
 | 
			
		||||
			imageData, err := utils.DownloadImage(imgURL, "")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
			} else {
 | 
			
		||||
				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SwapFace 换脸
 | 
			
		||||
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) != 2 {
 | 
			
		||||
		return ImageRes{}, errors.New("参数错误,必须上传2张图片")
 | 
			
		||||
	}
 | 
			
		||||
	var sourceBase64 string
 | 
			
		||||
	var targetBase64 string
 | 
			
		||||
	imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with download image: ", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
			
		||||
	}
 | 
			
		||||
	imageData, err = utils.DownloadImage(task.ImgArr[1], "")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with download image: ", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body := gin.H{
 | 
			
		||||
		"sourceBase64": sourceBase64,
 | 
			
		||||
		"targetBase64": targetBase64,
 | 
			
		||||
		"accountFilter": gin.H{
 | 
			
		||||
			"instanceId": "",
 | 
			
		||||
		},
 | 
			
		||||
		"state": "",
 | 
			
		||||
	}
 | 
			
		||||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
 | 
			
		||||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
 | 
			
		||||
 | 
			
		||||
	return c.doRequest(body, apiPath, task.ChannelId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
 | 
			
		||||
	if channel != "" {
 | 
			
		||||
		session = session.Where("api_url", channel)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	err := session.Order("last_used_at ASC").First(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
 | 
			
		||||
		return ImageRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		errMsg, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", string(errMsg))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update the api key last used time
 | 
			
		||||
	if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
 | 
			
		||||
		logger.Error("update api key last used time error: ", err)
 | 
			
		||||
	}
 | 
			
		||||
	res.Channel = apiKey.ApiURL
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
 | 
			
		||||
	var res QueryRes
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return QueryRes{}, errors.New("error status:" + r.Status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,267 +0,0 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"io"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PlusClient MidJourney Plus ProxyClient
 | 
			
		||||
type PlusClient struct {
 | 
			
		||||
	Config         types.MjPlusConfig
 | 
			
		||||
	apiURL         string
 | 
			
		||||
	client         *req.Client
 | 
			
		||||
	licenseService *service.LicenseService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
 | 
			
		||||
	return &PlusClient{
 | 
			
		||||
		Config:         config,
 | 
			
		||||
		apiURL:         config.ApiURL,
 | 
			
		||||
		client:         req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
 | 
			
		||||
		licenseService: licenseService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *PlusClient) preCheck() error {
 | 
			
		||||
	return c.licenseService.IsValidApiURL(c.Config.ApiURL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	if err := c.preCheck(); err != nil {
 | 
			
		||||
		return ImageRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
 | 
			
		||||
	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
 | 
			
		||||
	if task.NegPrompt != "" {
 | 
			
		||||
		prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
 | 
			
		||||
	}
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Prompt:      prompt,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blend 融图
 | 
			
		||||
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	if err := c.preCheck(); err != nil {
 | 
			
		||||
		return ImageRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Dimensions:  "SQUARE",
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		for _, imgURL := range task.ImgArr {
 | 
			
		||||
			imageData, err := utils.DownloadImage(imgURL, "")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
			} else {
 | 
			
		||||
				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SwapFace 换脸
 | 
			
		||||
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	if err := c.preCheck(); err != nil {
 | 
			
		||||
		return ImageRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) != 2 {
 | 
			
		||||
		return ImageRes{}, errors.New("参数错误,必须上传2张图片")
 | 
			
		||||
	}
 | 
			
		||||
	var sourceBase64 string
 | 
			
		||||
	var targetBase64 string
 | 
			
		||||
	imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with download image: ", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
			
		||||
	}
 | 
			
		||||
	imageData, err = utils.DownloadImage(task.ImgArr[1], "")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with download image: ", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body := gin.H{
 | 
			
		||||
		"sourceBase64": sourceBase64,
 | 
			
		||||
		"targetBase64": targetBase64,
 | 
			
		||||
		"accountFilter": gin.H{
 | 
			
		||||
			"instanceId": "",
 | 
			
		||||
		},
 | 
			
		||||
		"state": "",
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.SetTimeout(time.Minute).R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	if err := c.preCheck(); err != nil {
 | 
			
		||||
		return ImageRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	if err := c.preCheck(); err != nil {
 | 
			
		||||
		return ImageRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
 | 
			
		||||
	var res QueryRes
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return QueryRes{}, errors.New("error status:" + r.Status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Client = &PlusClient{}
 | 
			
		||||
@@ -1,230 +0,0 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ServicePool Mj service pool
 | 
			
		||||
type ServicePool struct {
 | 
			
		||||
	services        []*Service
 | 
			
		||||
	taskQueue       *store.RedisQueue
 | 
			
		||||
	notifyQueue     *store.RedisQueue
 | 
			
		||||
	db              *gorm.DB
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
	Clients         *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
	licenseService  *service.LicenseService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool {
 | 
			
		||||
	services := make([]*Service, 0)
 | 
			
		||||
	taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
 | 
			
		||||
	notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
 | 
			
		||||
	return &ServicePool{
 | 
			
		||||
		taskQueue:       taskQueue,
 | 
			
		||||
		notifyQueue:     notifyQueue,
 | 
			
		||||
		services:        services,
 | 
			
		||||
		uploaderManager: manager,
 | 
			
		||||
		db:              db,
 | 
			
		||||
		Clients:         types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		licenseService:  licenseService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
 | 
			
		||||
	// stop old service
 | 
			
		||||
	for _, s := range p.services {
 | 
			
		||||
		s.Stop()
 | 
			
		||||
	}
 | 
			
		||||
	p.services = make([]*Service, 0)
 | 
			
		||||
 | 
			
		||||
	for k, config := range plusConfigs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		cli := NewPlusClient(config, p.licenseService)
 | 
			
		||||
		name := fmt.Sprintf("mj-plus-service-%d", k)
 | 
			
		||||
		plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			plusService.Run()
 | 
			
		||||
		}()
 | 
			
		||||
		p.services = append(p.services, plusService)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// for mid-journey proxy
 | 
			
		||||
	for k, config := range proxyConfigs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		cli := NewProxyClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-proxy-service-%d", k)
 | 
			
		||||
		proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			proxyService.Run()
 | 
			
		||||
		}()
 | 
			
		||||
		p.services = append(p.services, proxyService)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var message sd.NotifyMessage
 | 
			
		||||
			err := p.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			cli := p.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if cli == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = cli.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) DownloadImages() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.MidJourneyJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// download images
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				if v.OrgURL == "" {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Infof("try to download image: %s", v.OrgURL)
 | 
			
		||||
				mjService := p.getService(v.ChannelId)
 | 
			
		||||
				if mjService == nil {
 | 
			
		||||
					logger.Errorf("Invalid task: %+v", v)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				task, _ := mjService.Client.QueryTask(v.TaskId)
 | 
			
		||||
				if len(task.Buttons) > 0 {
 | 
			
		||||
					v.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
				}
 | 
			
		||||
				// 如果是返回的是 discord 图片地址,则使用代理下载
 | 
			
		||||
				proxy := false
 | 
			
		||||
				if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
 | 
			
		||||
					proxy = true
 | 
			
		||||
				}
 | 
			
		||||
				imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("error with download image %s, %v", v.OrgURL, err)
 | 
			
		||||
					continue
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Infof("download image %s successfully.", v.OrgURL)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				v.ImgURL = imgURL
 | 
			
		||||
				p.db.Updates(&v)
 | 
			
		||||
 | 
			
		||||
				cli := p.Clients.Get(uint(v.UserId))
 | 
			
		||||
				if cli == nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				err = cli.Send([]byte(sd.Finished))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PushTask push a new mj task in to task queue
 | 
			
		||||
func (p *ServicePool) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	p.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasAvailableService check if it has available mj service in pool
 | 
			
		||||
func (p *ServicePool) HasAvailableService() bool {
 | 
			
		||||
	return len(p.services) > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SyncTaskProgress 异步拉取任务
 | 
			
		||||
func (p *ServicePool) SyncTaskProgress() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.MidJourneyJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := p.db.Where("progress < ?", 100).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range items {
 | 
			
		||||
				// 失败或者 30 分钟还没完成的任务删除并退回算力
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
 | 
			
		||||
					p.db.Delete(&job)
 | 
			
		||||
					// 退回算力
 | 
			
		||||
					tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
 | 
			
		||||
					if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
						var user model.User
 | 
			
		||||
						p.db.Where("id = ?", job.UserId).First(&user)
 | 
			
		||||
						p.db.Create(&model.PowerLog{
 | 
			
		||||
							UserId:    user.Id,
 | 
			
		||||
							Username:  user.Username,
 | 
			
		||||
							Type:      types.PowerConsume,
 | 
			
		||||
							Amount:    job.Power,
 | 
			
		||||
							Balance:   user.Power + job.Power,
 | 
			
		||||
							Mark:      types.PowerAdd,
 | 
			
		||||
							Model:     "mid-journey",
 | 
			
		||||
							Remark:    fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId),
 | 
			
		||||
							CreatedAt: time.Now(),
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
 | 
			
		||||
					_ = servicePlus.Notify(job)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) getService(name string) *Service {
 | 
			
		||||
	for _, s := range p.services {
 | 
			
		||||
		if s.Name == name {
 | 
			
		||||
			return s
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,185 +0,0 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"io"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ProxyClient MidJourney Proxy Client
 | 
			
		||||
type ProxyClient struct {
 | 
			
		||||
	Config types.MjProxyConfig
 | 
			
		||||
	apiURL string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
 | 
			
		||||
	return &ProxyClient{Config: config, apiURL: config.ApiURL}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
 | 
			
		||||
	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
 | 
			
		||||
	if task.NegPrompt != "" {
 | 
			
		||||
		prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
 | 
			
		||||
	}
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		Prompt:      prompt,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blend 融图
 | 
			
		||||
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		Dimensions:  "SQUARE",
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		for _, imgURL := range task.ImgArr {
 | 
			
		||||
			imageData, err := utils.DownloadImage(imgURL, "")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
			} else {
 | 
			
		||||
				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SwapFace 换脸
 | 
			
		||||
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
 | 
			
		||||
	return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]interface{}{
 | 
			
		||||
		"action": "UPSCALE",
 | 
			
		||||
		"index":  task.Index,
 | 
			
		||||
		"taskId": task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]interface{}{
 | 
			
		||||
		"action": "VARIATION",
 | 
			
		||||
		"index":  task.Index,
 | 
			
		||||
		"taskId": task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
 | 
			
		||||
	var res QueryRes
 | 
			
		||||
	r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return QueryRes{}, errors.New("error status:" + r.Status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Client = &ProxyClient{}
 | 
			
		||||
@@ -11,10 +11,11 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@@ -23,118 +24,112 @@ import (
 | 
			
		||||
 | 
			
		||||
// Service MJ 绘画服务
 | 
			
		||||
type Service struct {
 | 
			
		||||
	Name        string // service Name
 | 
			
		||||
	Client      Client // MJ Client
 | 
			
		||||
	taskQueue   *store.RedisQueue
 | 
			
		||||
	notifyQueue *store.RedisQueue
 | 
			
		||||
	db          *gorm.DB
 | 
			
		||||
	running     bool
 | 
			
		||||
	client          *Client // MJ Client
 | 
			
		||||
	taskQueue       *store.RedisQueue
 | 
			
		||||
	notifyQueue     *store.RedisQueue
 | 
			
		||||
	db              *gorm.DB
 | 
			
		||||
	Clients         *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
 | 
			
		||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		Name:        name,
 | 
			
		||||
		db:          db,
 | 
			
		||||
		taskQueue:   taskQueue,
 | 
			
		||||
		notifyQueue: notifyQueue,
 | 
			
		||||
		Client:      cli,
 | 
			
		||||
		running:     true,
 | 
			
		||||
		db:              db,
 | 
			
		||||
		taskQueue:       store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
 | 
			
		||||
		notifyQueue:     store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
 | 
			
		||||
		client:          client,
 | 
			
		||||
		Clients:         types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploaderManager: manager,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Infof("Starting MidJourney job consumer for %s", s.Name)
 | 
			
		||||
	for s.running {
 | 
			
		||||
		var task types.MjTask
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		//  如果配置了多个中转平台的 API KEY
 | 
			
		||||
		// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
 | 
			
		||||
		if task.ChannelId != "" && task.ChannelId != s.Name {
 | 
			
		||||
			logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
 | 
			
		||||
			s.taskQueue.RPush(task)
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// translate prompt
 | 
			
		||||
		if utils.HasChinese(task.Prompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Prompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// translate negative prompt
 | 
			
		||||
		if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.NegPrompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var job model.MidJourneyJob
 | 
			
		||||
		tx := s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if tx.Error != nil {
 | 
			
		||||
			logger.Error("任务不存在,任务ID:", task.TaskId)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
 | 
			
		||||
		var res ImageRes
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			res, err = s.Client.Imagine(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskUpscale:
 | 
			
		||||
			res, err = s.Client.Upscale(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			res, err = s.Client.Variation(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskBlend:
 | 
			
		||||
			res, err = s.Client.Blend(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskSwapFace:
 | 
			
		||||
			res, err = s.Client.SwapFace(task)
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err != nil || (res.Code != 1 && res.Code != 22) {
 | 
			
		||||
			var errMsg string
 | 
			
		||||
	logger.Info("Starting MidJourney job consumer for service")
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var task types.MjTask
 | 
			
		||||
			err := s.taskQueue.LPop(&task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				errMsg = err.Error()
 | 
			
		||||
			} else {
 | 
			
		||||
				errMsg = fmt.Sprintf("%v,%s", err, res.Description)
 | 
			
		||||
				logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Error("绘画任务执行失败:", errMsg)
 | 
			
		||||
			job.Progress = -1
 | 
			
		||||
			job.ErrMsg = errMsg
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Updates(&job)
 | 
			
		||||
			// 任务失败,通知前端
 | 
			
		||||
			s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("任务提交成功:%+v", res)
 | 
			
		||||
		// 更新任务 ID/频道
 | 
			
		||||
		job.TaskId = res.Result
 | 
			
		||||
		job.MessageId = res.Result
 | 
			
		||||
		job.ChannelId = s.Name
 | 
			
		||||
		s.db.Updates(&job)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
			// translate prompt
 | 
			
		||||
			if utils.HasChinese(task.Prompt) {
 | 
			
		||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					task.Prompt = content
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// translate negative prompt
 | 
			
		||||
			if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
 | 
			
		||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					task.NegPrompt = content
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Stop() {
 | 
			
		||||
	s.running = false
 | 
			
		||||
			// use fast mode as default
 | 
			
		||||
			if task.Mode == "" {
 | 
			
		||||
				task.Mode = "fast"
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var job model.MidJourneyJob
 | 
			
		||||
			tx := s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
			if tx.Error != nil {
 | 
			
		||||
				logger.Error("任务不存在,任务ID:", task.TaskId)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Infof("handle a new MidJourney task: %+v", task)
 | 
			
		||||
			var res ImageRes
 | 
			
		||||
			switch task.Type {
 | 
			
		||||
			case types.TaskImage:
 | 
			
		||||
				res, err = s.client.Imagine(task)
 | 
			
		||||
				break
 | 
			
		||||
			case types.TaskUpscale:
 | 
			
		||||
				res, err = s.client.Upscale(task)
 | 
			
		||||
				break
 | 
			
		||||
			case types.TaskVariation:
 | 
			
		||||
				res, err = s.client.Variation(task)
 | 
			
		||||
				break
 | 
			
		||||
			case types.TaskBlend:
 | 
			
		||||
				res, err = s.client.Blend(task)
 | 
			
		||||
				break
 | 
			
		||||
			case types.TaskSwapFace:
 | 
			
		||||
				res, err = s.client.SwapFace(task)
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != nil || (res.Code != 1 && res.Code != 22) {
 | 
			
		||||
				var errMsg string
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					errMsg = err.Error()
 | 
			
		||||
				} else {
 | 
			
		||||
					errMsg = fmt.Sprintf("%v,%s", err, res.Description)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Error("绘画任务执行失败:", errMsg)
 | 
			
		||||
				job.Progress = service.FailTaskProgress
 | 
			
		||||
				job.ErrMsg = errMsg
 | 
			
		||||
				// update the task progress
 | 
			
		||||
				s.db.Updates(&job)
 | 
			
		||||
				// 任务失败,通知前端
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			logger.Infof("任务提交成功:%+v", res)
 | 
			
		||||
			// 更新任务 ID/频道
 | 
			
		||||
			job.TaskId = res.Result
 | 
			
		||||
			job.MessageId = res.Result
 | 
			
		||||
			job.ChannelId = res.Channel
 | 
			
		||||
			s.db.Updates(&job)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
@@ -155,46 +150,6 @@ type CBReq struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(job model.MidJourneyJob) error {
 | 
			
		||||
	task, err := s.Client.QueryTask(job.TaskId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 任务执行失败了
 | 
			
		||||
	if task.FailReason != "" {
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
			"progress": -1,
 | 
			
		||||
			"err_msg":  task.FailReason,
 | 
			
		||||
		})
 | 
			
		||||
		s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
 | 
			
		||||
		return fmt.Errorf("task failed: %v", task.FailReason)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(task.Buttons) > 0 {
 | 
			
		||||
		job.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
	}
 | 
			
		||||
	oldProgress := job.Progress
 | 
			
		||||
	job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
 | 
			
		||||
	job.Prompt = task.PromptEn
 | 
			
		||||
	if task.ImageUrl != "" {
 | 
			
		||||
		job.OrgURL = task.ImageUrl
 | 
			
		||||
	}
 | 
			
		||||
	tx := s.db.Updates(&job)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return fmt.Errorf("error with update database: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
	// 通知前端更新任务进度
 | 
			
		||||
	if oldProgress != job.Progress {
 | 
			
		||||
		message := sd.Running
 | 
			
		||||
		if job.Progress == 100 {
 | 
			
		||||
			message = sd.Finished
 | 
			
		||||
		}
 | 
			
		||||
		s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetImageHash(action string) string {
 | 
			
		||||
	split := strings.Split(action, "::")
 | 
			
		||||
	if len(split) > 5 {
 | 
			
		||||
@@ -202,3 +157,143 @@ func GetImageHash(action string) string {
 | 
			
		||||
	}
 | 
			
		||||
	return split[len(split)-1]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var message service.NotifyMessage
 | 
			
		||||
			err := s.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			cli := s.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if cli == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = cli.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) DownloadImages() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.MidJourneyJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// download images
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				if v.OrgURL == "" {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Infof("try to download image: %s", v.OrgURL)
 | 
			
		||||
				// 如果是返回的是 discord 图片地址,则使用代理下载
 | 
			
		||||
				proxy := false
 | 
			
		||||
				if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
 | 
			
		||||
					proxy = true
 | 
			
		||||
				}
 | 
			
		||||
				imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("error with download image %s, %v", v.OrgURL, err)
 | 
			
		||||
					continue
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Infof("download image %s successfully.", v.OrgURL)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				v.ImgURL = imgURL
 | 
			
		||||
				s.db.Updates(&v)
 | 
			
		||||
 | 
			
		||||
				cli := s.Clients.Get(uint(v.UserId))
 | 
			
		||||
				if cli == nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				err = cli.Send([]byte(service.TaskStatusFinished))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PushTask push a new mj task in to task queue
 | 
			
		||||
func (s *Service) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SyncTaskProgress 异步拉取任务
 | 
			
		||||
func (s *Service) SyncTaskProgress() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var jobs []model.MidJourneyJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 10 分钟还没完成的任务标记为失败
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
 | 
			
		||||
					job.Progress = service.FailTaskProgress
 | 
			
		||||
					job.ErrMsg = "任务超时"
 | 
			
		||||
					s.db.Updates(&job)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("error with query task: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 任务执行失败了
 | 
			
		||||
				if task.FailReason != "" {
 | 
			
		||||
					s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
						"progress": service.FailTaskProgress,
 | 
			
		||||
						"err_msg":  task.FailReason,
 | 
			
		||||
					})
 | 
			
		||||
					logger.Errorf("task failed: %v", task.FailReason)
 | 
			
		||||
					s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if len(task.Buttons) > 0 {
 | 
			
		||||
					job.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
				}
 | 
			
		||||
				oldProgress := job.Progress
 | 
			
		||||
				job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
 | 
			
		||||
				job.Prompt = task.PromptEn
 | 
			
		||||
				if task.ImageUrl != "" {
 | 
			
		||||
					job.OrgURL = task.ImageUrl
 | 
			
		||||
				}
 | 
			
		||||
				err = s.db.Updates(&job).Error
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("error with update database: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 通知前端更新任务进度
 | 
			
		||||
				if oldProgress != job.Progress {
 | 
			
		||||
					message := service.TaskStatusRunning
 | 
			
		||||
					if job.Progress == 100 {
 | 
			
		||||
						message = service.TaskStatusFinished
 | 
			
		||||
					}
 | 
			
		||||
					s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -84,25 +84,25 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	var imageData []byte
 | 
			
		||||
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
 | 
			
		||||
	var fileData []byte
 | 
			
		||||
	var err error
 | 
			
		||||
	if useProxy {
 | 
			
		||||
		imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
 | 
			
		||||
		fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
 | 
			
		||||
	} else {
 | 
			
		||||
		imageData, err = utils.DownloadImage(imageURL, "")
 | 
			
		||||
		fileData, err = utils.DownloadImage(fileURL, "")
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with download image: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	parse, err := url.Parse(imageURL)
 | 
			
		||||
	parse, err := url.Parse(fileURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	fileExt := utils.GetImgExt(parse.Path)
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
	err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
 | 
			
		||||
	err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -57,8 +57,8 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	parse, err := url.Parse(imageURL)
 | 
			
		||||
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
 | 
			
		||||
	parse, err := url.Parse(fileURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -69,9 +69,9 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if useProxy {
 | 
			
		||||
		err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
 | 
			
		||||
		err = utils.DownloadFile(fileURL, filePath, s.proxyURL)
 | 
			
		||||
	} else {
 | 
			
		||||
		err = utils.DownloadFile(imageURL, filePath, "")
 | 
			
		||||
		err = utils.DownloadFile(fileURL, filePath, "")
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with download image: %v", err)
 | 
			
		||||
 
 | 
			
		||||
@@ -44,18 +44,18 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
 | 
			
		||||
	return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	var imageData []byte
 | 
			
		||||
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
 | 
			
		||||
	var fileData []byte
 | 
			
		||||
	var err error
 | 
			
		||||
	if useProxy {
 | 
			
		||||
		imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
 | 
			
		||||
		fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
 | 
			
		||||
	} else {
 | 
			
		||||
		imageData, err = utils.DownloadImage(imageURL, "")
 | 
			
		||||
		fileData, err = utils.DownloadImage(fileURL, "")
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with download image: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	parse, err := url.Parse(imageURL)
 | 
			
		||||
	parse, err := url.Parse(fileURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -65,8 +65,8 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		s.config.Bucket,
 | 
			
		||||
		filename,
 | 
			
		||||
		strings.NewReader(string(imageData)),
 | 
			
		||||
		int64(len(imageData)),
 | 
			
		||||
		strings.NewReader(string(fileData)),
 | 
			
		||||
		int64(len(fileData)),
 | 
			
		||||
		minio.PutObjectOptions{ContentType: "image/png"})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
 
 | 
			
		||||
@@ -93,18 +93,18 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	var imageData []byte
 | 
			
		||||
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
 | 
			
		||||
	var fileData []byte
 | 
			
		||||
	var err error
 | 
			
		||||
	if useProxy {
 | 
			
		||||
		imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
 | 
			
		||||
		fileData, err = utils.DownloadImage(fileURL, s.proxyURL)
 | 
			
		||||
	} else {
 | 
			
		||||
		imageData, err = utils.DownloadImage(imageURL, "")
 | 
			
		||||
		fileData, err = utils.DownloadImage(fileURL, "")
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with download image: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	parse, err := url.Parse(imageURL)
 | 
			
		||||
	parse, err := url.Parse(fileURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -113,7 +113,7 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	ret := storage.PutRet{}
 | 
			
		||||
	extra := storage.PutExtra{}
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
	err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra)
 | 
			
		||||
	err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(fileData), int64(len(fileData)), &extra)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,7 @@ type File struct {
 | 
			
		||||
}
 | 
			
		||||
type Uploader interface {
 | 
			
		||||
	PutFile(ctx *gin.Context, name string) (File, error)
 | 
			
		||||
	PutImg(imageURL string, useProxy bool) (string, error)
 | 
			
		||||
	PutUrlFile(url string, useProxy bool) (string, error)
 | 
			
		||||
	PutBase64(imageData string) (string, error)
 | 
			
		||||
	Delete(fileURL string) error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,12 +8,13 @@ package payment
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"github.com/smartwalle/alipay/v3"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"github.com/go-pay/gopay"
 | 
			
		||||
	"github.com/go-pay/gopay/alipay"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -35,93 +36,90 @@ func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
 | 
			
		||||
		return nil, fmt.Errorf("error with read App Private key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	xClient, err := alipay.New(config.AppId, priKey, !config.SandBox)
 | 
			
		||||
	client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with initialize alipay service: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with loading App PublicKey: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with loading alipay RootCert: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err)
 | 
			
		||||
	//client.DebugSwitch = gopay.DebugOn // 开启调试模式
 | 
			
		||||
	client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
 | 
			
		||||
		SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
 | 
			
		||||
		SetSignType(alipay.RSA2). // 设置签名类型,不设置默认 RSA2
 | 
			
		||||
		SetReturnUrl(config.ReturnURL). // 设置返回URL
 | 
			
		||||
		SetNotifyUrl(config.NotifyURL)
 | 
			
		||||
 | 
			
		||||
	if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with load payment public key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &AlipayService{config: &config, client: xClient}, nil
 | 
			
		||||
	return &AlipayService{config: &config, client: client}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) {
 | 
			
		||||
	var p = alipay.TradeWapPay{}
 | 
			
		||||
	p.NotifyURL = notifyURL
 | 
			
		||||
	p.ReturnURL = returnURL
 | 
			
		||||
	p.Subject = subject
 | 
			
		||||
	p.OutTradeNo = outTradeNo
 | 
			
		||||
	p.TotalAmount = Amount
 | 
			
		||||
	p.ProductCode = "QUICK_WAP_WAY"
 | 
			
		||||
	res, err := s.client.TradeWapPay(p)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res.String(), err
 | 
			
		||||
func (s *AlipayService) PayUrlMobile(outTradeNo string, amount string, subject string) (string, error) {
 | 
			
		||||
	bm := make(gopay.BodyMap)
 | 
			
		||||
	bm.Set("subject", subject)
 | 
			
		||||
	bm.Set("out_trade_no", outTradeNo)
 | 
			
		||||
	bm.Set("quit_url", s.config.ReturnURL)
 | 
			
		||||
	bm.Set("total_amount", amount)
 | 
			
		||||
	bm.Set("product_code", "QUICK_WAP_WAY")
 | 
			
		||||
	return s.client.TradeWapPay(context.Background(), bm)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) {
 | 
			
		||||
	var p = alipay.TradePagePay{}
 | 
			
		||||
	p.NotifyURL = notifyURL
 | 
			
		||||
	p.ReturnURL = returnURL
 | 
			
		||||
	p.Subject = subject
 | 
			
		||||
	p.OutTradeNo = outTradeNo
 | 
			
		||||
	p.TotalAmount = amount
 | 
			
		||||
	p.ProductCode = "FAST_INSTANT_TRADE_PAY"
 | 
			
		||||
	res, err := s.client.TradePagePay(p)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res.String(), err
 | 
			
		||||
func (s *AlipayService) PayUrlPc(outTradeNo string, amount string, subject string) (string, error) {
 | 
			
		||||
	bm := make(gopay.BodyMap)
 | 
			
		||||
	bm.Set("subject", subject)
 | 
			
		||||
	bm.Set("out_trade_no", outTradeNo)
 | 
			
		||||
	bm.Set("total_amount", amount)
 | 
			
		||||
	bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
 | 
			
		||||
	return s.client.TradePagePay(context.Background(), bm)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TradeVerify 交易验证
 | 
			
		||||
func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo {
 | 
			
		||||
	err := s.client.VerifySign(reqForm)
 | 
			
		||||
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
 | 
			
		||||
	notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Println("异步通知验证签名发生错误", err)
 | 
			
		||||
		return NotifyVo{
 | 
			
		||||
			Status:  0,
 | 
			
		||||
			Message: "异步通知验证签名发生错误",
 | 
			
		||||
			Status:  Failure,
 | 
			
		||||
			Message: "error with parse notify request: " + err.Error(),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return s.TradeQuery(reqForm.Get("out_trade_no"))
 | 
			
		||||
	_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return NotifyVo{
 | 
			
		||||
			Status:  Failure,
 | 
			
		||||
			Message: "error with verify sign: " + err.Error(),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return s.TradeQuery(request.Form.Get("out_trade_no"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
 | 
			
		||||
	var p = alipay.TradeQuery{}
 | 
			
		||||
	p.OutTradeNo = outTradeNo
 | 
			
		||||
	rsp, err := s.client.TradeQuery(p)
 | 
			
		||||
	bm := make(gopay.BodyMap)
 | 
			
		||||
	bm.Set("out_trade_no", outTradeNo)
 | 
			
		||||
 | 
			
		||||
	//查询订单
 | 
			
		||||
	rsp, err := s.client.TradeQuery(context.Background(), bm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return NotifyVo{
 | 
			
		||||
			Status:  0,
 | 
			
		||||
			Status:  Failure,
 | 
			
		||||
			Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" {
 | 
			
		||||
	if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
 | 
			
		||||
		return NotifyVo{
 | 
			
		||||
			Status:     1,
 | 
			
		||||
			OutTradeNo: rsp.OutTradeNo,
 | 
			
		||||
			TradeNo:    rsp.TradeNo,
 | 
			
		||||
			Amount:     rsp.TotalAmount,
 | 
			
		||||
			Subject:    rsp.Subject,
 | 
			
		||||
			Status:     Success,
 | 
			
		||||
			OutTradeNo: rsp.Response.OutTradeNo,
 | 
			
		||||
			TradeId:    rsp.Response.TradeNo,
 | 
			
		||||
			Amount:     rsp.Response.TotalAmount,
 | 
			
		||||
			Subject:    rsp.Response.Subject,
 | 
			
		||||
			Message:    "OK",
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		return NotifyVo{
 | 
			
		||||
			Status:  0,
 | 
			
		||||
			Status:  Failure,
 | 
			
		||||
			Message: "异步查询验证订单信息发生错误" + outTradeNo,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -134,16 +132,3 @@ func readKey(filename string) (string, error) {
 | 
			
		||||
	}
 | 
			
		||||
	return string(data), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NotifyVo struct {
 | 
			
		||||
	Status     int
 | 
			
		||||
	OutTradeNo string
 | 
			
		||||
	TradeNo    string
 | 
			
		||||
	Amount     string
 | 
			
		||||
	Message    string
 | 
			
		||||
	Subject    string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (v NotifyVo) Success() bool {
 | 
			
		||||
	return v.Status == 1
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,12 +21,12 @@ import (
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PayJS struct {
 | 
			
		||||
type JPayService struct {
 | 
			
		||||
	config *types.JPayConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPayJS(appConfig *types.AppConfig) *PayJS {
 | 
			
		||||
	return &PayJS{
 | 
			
		||||
func NewJPayService(appConfig *types.AppConfig) *JPayService {
 | 
			
		||||
	return &JPayService{
 | 
			
		||||
		config: &appConfig.JPayConfig,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -53,7 +53,7 @@ func (r JPayReps) IsOK() bool {
 | 
			
		||||
	return r.ReturnMsg == "SUCCESS"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (js *PayJS) Pay(param JPayReq) JPayReps {
 | 
			
		||||
func (js *JPayService) Pay(param JPayReq) JPayReps {
 | 
			
		||||
	param.NotifyURL = js.config.NotifyURL
 | 
			
		||||
	var p = url.Values{}
 | 
			
		||||
	encode := utils.JsonEncode(param)
 | 
			
		||||
@@ -86,13 +86,13 @@ func (js *PayJS) Pay(param JPayReq) JPayReps {
 | 
			
		||||
	return data
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (js *PayJS) PayH5(p url.Values) string {
 | 
			
		||||
func (js *JPayService) PayH5(p url.Values) string {
 | 
			
		||||
	p.Add("mchid", js.config.AppId)
 | 
			
		||||
	p.Add("sign", js.sign(p))
 | 
			
		||||
	return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (js *PayJS) sign(params url.Values) string {
 | 
			
		||||
func (js *JPayService) sign(params url.Values) string {
 | 
			
		||||
	params.Del(`sign`)
 | 
			
		||||
	var keys = make([]string, 0, 0)
 | 
			
		||||
	for key := range params {
 | 
			
		||||
@@ -117,20 +117,18 @@ func (js *PayJS) sign(params url.Values) string {
 | 
			
		||||
	return strings.ToUpper(md5res)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Check 查询订单支付状态
 | 
			
		||||
// TradeVerify 查询订单支付状态
 | 
			
		||||
// @param tradeNo 支付平台交易 ID
 | 
			
		||||
func (js *PayJS) Check(tradeNo string) error {
 | 
			
		||||
func (js *JPayService) TradeVerify(tradeNo string) error {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
 | 
			
		||||
	params := url.Values{}
 | 
			
		||||
	params.Add("payjs_order_id", tradeNo)
 | 
			
		||||
	params.Add("sign", js.sign(params))
 | 
			
		||||
	data := strings.NewReader(params.Encode())
 | 
			
		||||
	resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with http reqeust: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										19
									
								
								api/service/payment/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								api/service/payment/types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package payment
 | 
			
		||||
 | 
			
		||||
type NotifyVo struct {
 | 
			
		||||
	Status     int
 | 
			
		||||
	OutTradeNo string // 商户订单号
 | 
			
		||||
	TradeId    string // 交易ID
 | 
			
		||||
	Amount     string // 交易金额
 | 
			
		||||
	Message    string
 | 
			
		||||
	Subject    string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (v NotifyVo) Success() bool {
 | 
			
		||||
	return v.Status == Success
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	Success = 0
 | 
			
		||||
	Failure = 1
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										135
									
								
								api/service/payment/wepay_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								api/service/payment/wepay_service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,135 @@
 | 
			
		||||
package payment
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"github.com/go-pay/gopay"
 | 
			
		||||
	"github.com/go-pay/gopay/wechat/v3"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type WechatPayService struct {
 | 
			
		||||
	config *types.WechatPayConfig
 | 
			
		||||
	client *wechat.ClientV3
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
 | 
			
		||||
	config := appConfig.WechatPayConfig
 | 
			
		||||
	if !config.Enabled {
 | 
			
		||||
		logger.Info("Disabled WechatPay service")
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	priKey, err := readKey(config.PrivateKey)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with read App Private key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	err = client.AutoVerifySign()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error with autoVerifySign: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	//client.DebugSwitch = gopay.DebugOn
 | 
			
		||||
 | 
			
		||||
	return &WechatPayService{config: &config, client: client}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *WechatPayService) PayUrlNative(outTradeNo string, amount int, subject string) (string, error) {
 | 
			
		||||
	expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
 | 
			
		||||
	// 初始化 BodyMap
 | 
			
		||||
	bm := make(gopay.BodyMap)
 | 
			
		||||
	bm.Set("appid", s.config.AppId).
 | 
			
		||||
		Set("mchid", s.config.MchId).
 | 
			
		||||
		Set("description", subject).
 | 
			
		||||
		Set("out_trade_no", outTradeNo).
 | 
			
		||||
		Set("time_expire", expire).
 | 
			
		||||
		Set("notify_url", s.config.NotifyURL).
 | 
			
		||||
		SetBodyMap("amount", func(bm gopay.BodyMap) {
 | 
			
		||||
			bm.Set("total", amount).
 | 
			
		||||
				Set("currency", "CNY")
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
	wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if wxRsp.Code != wechat.Success {
 | 
			
		||||
		return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
 | 
			
		||||
	}
 | 
			
		||||
	return wxRsp.Response.CodeUrl, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *WechatPayService) PayUrlH5(outTradeNo string, amount int, subject string, ip string) (string, error) {
 | 
			
		||||
	expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
 | 
			
		||||
	// 初始化 BodyMap
 | 
			
		||||
	bm := make(gopay.BodyMap)
 | 
			
		||||
	bm.Set("appid", s.config.AppId).
 | 
			
		||||
		Set("mchid", s.config.MchId).
 | 
			
		||||
		Set("description", subject).
 | 
			
		||||
		Set("out_trade_no", outTradeNo).
 | 
			
		||||
		Set("time_expire", expire).
 | 
			
		||||
		Set("notify_url", s.config.NotifyURL).
 | 
			
		||||
		SetBodyMap("amount", func(bm gopay.BodyMap) {
 | 
			
		||||
			bm.Set("total", amount).
 | 
			
		||||
				Set("currency", "CNY")
 | 
			
		||||
		}).
 | 
			
		||||
		SetBodyMap("scene_info", func(bm gopay.BodyMap) {
 | 
			
		||||
			bm.Set("payer_client_ip", ip).
 | 
			
		||||
				SetBodyMap("h5_info", func(bm gopay.BodyMap) {
 | 
			
		||||
					bm.Set("type", "Wap")
 | 
			
		||||
				})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
	wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if wxRsp.Code != wechat.Success {
 | 
			
		||||
		return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
 | 
			
		||||
	}
 | 
			
		||||
	return wxRsp.Response.H5Url, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NotifyResponse struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
	Message string `xml:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TradeVerify 交易验证
 | 
			
		||||
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
 | 
			
		||||
	notifyReq, err := wechat.V3ParseNotify(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: 这里验签程序有 Bug,一直报错:crypto/rsa: verification error,先暂时取消验签
 | 
			
		||||
	//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
 | 
			
		||||
	//if err != nil {
 | 
			
		||||
	//	return fmt.Errorf("error with client v3 verify sign: %v", err)
 | 
			
		||||
	//}
 | 
			
		||||
 | 
			
		||||
	// 解密支付密文,验证订单信息
 | 
			
		||||
	result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return NotifyVo{
 | 
			
		||||
		Status:     Success,
 | 
			
		||||
		OutTradeNo: result.OutTradeNo,
 | 
			
		||||
		TradeId:    result.TransactionId,
 | 
			
		||||
		Amount:     fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,143 +0,0 @@
 | 
			
		||||
package sd
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ServicePool struct {
 | 
			
		||||
	services    []*Service
 | 
			
		||||
	taskQueue   *store.RedisQueue
 | 
			
		||||
	notifyQueue *store.RedisQueue
 | 
			
		||||
	db          *gorm.DB
 | 
			
		||||
	Clients     *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
	uploader    *oss.UploaderManager
 | 
			
		||||
	levelDB     *store.LevelDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
 | 
			
		||||
	services := make([]*Service, 0)
 | 
			
		||||
	taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
 | 
			
		||||
	notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
 | 
			
		||||
 | 
			
		||||
	return &ServicePool{
 | 
			
		||||
		taskQueue:   taskQueue,
 | 
			
		||||
		notifyQueue: notifyQueue,
 | 
			
		||||
		services:    services,
 | 
			
		||||
		db:          db,
 | 
			
		||||
		Clients:     types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploader:    manager,
 | 
			
		||||
		levelDB:     levelDB,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
 | 
			
		||||
	// stop old service
 | 
			
		||||
	for _, s := range p.services {
 | 
			
		||||
		s.Stop()
 | 
			
		||||
	}
 | 
			
		||||
	p.services = make([]*Service, 0)
 | 
			
		||||
 | 
			
		||||
	for k, config := range configs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// create sd service
 | 
			
		||||
		name := fmt.Sprintf(" sd-service-%d", k)
 | 
			
		||||
		service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
 | 
			
		||||
		// run sd service
 | 
			
		||||
		go func() {
 | 
			
		||||
			service.Run()
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		p.services = append(p.services, service)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PushTask push a new mj task in to task queue
 | 
			
		||||
func (p *ServicePool) PushTask(task types.SdTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	p.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message NotifyMessage
 | 
			
		||||
			err := p.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := p.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if client == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (p *ServicePool) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.SdJob
 | 
			
		||||
			res := p.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 5 分钟还没完成的任务直接删除
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
 | 
			
		||||
					p.db.Delete(&job)
 | 
			
		||||
					var user model.User
 | 
			
		||||
					p.db.Where("id = ?", job.UserId).First(&user)
 | 
			
		||||
					// 退回绘图次数
 | 
			
		||||
					res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
 | 
			
		||||
					if res.Error == nil && res.RowsAffected > 0 {
 | 
			
		||||
						p.db.Create(&model.PowerLog{
 | 
			
		||||
							UserId:    user.Id,
 | 
			
		||||
							Username:  user.Username,
 | 
			
		||||
							Type:      types.PowerConsume,
 | 
			
		||||
							Amount:    job.Power,
 | 
			
		||||
							Balance:   user.Power + job.Power,
 | 
			
		||||
							Mark:      types.PowerAdd,
 | 
			
		||||
							Model:     "stable-diffusion",
 | 
			
		||||
							Remark:    fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId),
 | 
			
		||||
							CreatedAt: time.Now(),
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasAvailableService check if it has available mj service in pool
 | 
			
		||||
func (p *ServicePool) HasAvailableService() bool {
 | 
			
		||||
	return len(p.services) > 0
 | 
			
		||||
}
 | 
			
		||||
@@ -10,95 +10,91 @@ package sd
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
// SD 绘画服务
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	httpClient    *req.Client
 | 
			
		||||
	config        types.StableDiffusionConfig
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	notifyQueue   *store.RedisQueue
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	name          string // service name
 | 
			
		||||
	leveldb       *store.LevelDB
 | 
			
		||||
	running       bool // 运行状态
 | 
			
		||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
 | 
			
		||||
	config.ApiURL = strings.TrimRight(config.ApiURL, "/")
 | 
			
		||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB, redisCli *redis.Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		name:          name,
 | 
			
		||||
		config:        config,
 | 
			
		||||
		httpClient:    req.C(),
 | 
			
		||||
		taskQueue:     taskQueue,
 | 
			
		||||
		notifyQueue:   notifyQueue,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
 | 
			
		||||
		notifyQueue:   store.NewRedisQueue("StableDiffusion_Queue", redisCli),
 | 
			
		||||
		db:            db,
 | 
			
		||||
		leveldb:       levelDB,
 | 
			
		||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		running:       true,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
 | 
			
		||||
	for s.running {
 | 
			
		||||
		var task types.SdTask
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	logger.Infof("Starting Stable-Diffusion job consumer")
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var task types.SdTask
 | 
			
		||||
			err := s.taskQueue.LPop(&task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		// translate prompt
 | 
			
		||||
		if utils.HasChinese(task.Params.Prompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Params.Prompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			// translate prompt
 | 
			
		||||
			if utils.HasChinese(task.Params.Prompt) {
 | 
			
		||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					task.Params.Prompt = content
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// translate negative prompt
 | 
			
		||||
			if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
 | 
			
		||||
				content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					task.Params.NegPrompt = content
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Infof("handle a new Stable-Diffusion task: %+v", task)
 | 
			
		||||
			err = s.Txt2Img(task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("绘画任务执行失败:", err.Error())
 | 
			
		||||
				// update the task progress
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"progress": service.FailTaskProgress,
 | 
			
		||||
					"err_msg":  err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				// 通知前端,任务失败
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// translate negative prompt
 | 
			
		||||
		if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Params.NegPrompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
 | 
			
		||||
		err = s.Txt2Img(task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err.Error())
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
				"progress": -1,
 | 
			
		||||
				"err_msg":  err.Error(),
 | 
			
		||||
			})
 | 
			
		||||
			// 通知前端,任务失败
 | 
			
		||||
			s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Stop() {
 | 
			
		||||
	s.running = false
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Txt2ImgReq 文生图请求实体
 | 
			
		||||
@@ -160,12 +156,19 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
	}
 | 
			
		||||
	var res Txt2ImgResp
 | 
			
		||||
	var errChan = make(chan error)
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
 | 
			
		||||
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	err := s.db.Where("type", "sd").Where("enabled", true).Order("last_used_at ASC").First(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("no available Stable-Diffusion api key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", apiKey.ApiURL)
 | 
			
		||||
	logger.Debugf("send image request to %s", apiURL)
 | 
			
		||||
	// send a request to sd api endpoint
 | 
			
		||||
	go func() {
 | 
			
		||||
		response, err := s.httpClient.R().
 | 
			
		||||
			SetHeader("Authorization", s.config.ApiKey).
 | 
			
		||||
			SetHeader("Authorization", apiKey.Value).
 | 
			
		||||
			SetBody(body).
 | 
			
		||||
			SetSuccessResult(&res).
 | 
			
		||||
			Post(apiURL)
 | 
			
		||||
@@ -178,6 +181,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// update the last used time
 | 
			
		||||
		apiKey.LastUsedAt = time.Now().Unix()
 | 
			
		||||
		s.db.Updates(&apiKey)
 | 
			
		||||
 | 
			
		||||
		// 保存 Base64 图片
 | 
			
		||||
		imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -192,7 +199,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))
 | 
			
		||||
		s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
 | 
			
		||||
		s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt})
 | 
			
		||||
		errChan <- nil
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
@@ -206,17 +213,17 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
 | 
			
		||||
			// task finished
 | 
			
		||||
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
 | 
			
		||||
			s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
 | 
			
		||||
			s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
 | 
			
		||||
			// 从 leveldb 中删除预览图片数据
 | 
			
		||||
			_ = s.leveldb.Delete(task.Params.TaskId)
 | 
			
		||||
			return nil
 | 
			
		||||
		default:
 | 
			
		||||
			err, resp := s.checkTaskProgress()
 | 
			
		||||
			err, resp := s.checkTaskProgress(apiKey)
 | 
			
		||||
			// 更新任务进度
 | 
			
		||||
			if err == nil && resp.Progress > 0 {
 | 
			
		||||
				s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
 | 
			
		||||
				// 发送更新状态信号
 | 
			
		||||
				s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
 | 
			
		||||
				// 保存预览图片数据
 | 
			
		||||
				if resp.CurrentImage != "" {
 | 
			
		||||
					_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
 | 
			
		||||
@@ -229,11 +236,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行任务
 | 
			
		||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
 | 
			
		||||
func (s *Service) checkTaskProgress(apiKey model.ApiKey) (error, *TaskProgressResp) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", apiKey.ApiURL)
 | 
			
		||||
	var res TaskProgressResp
 | 
			
		||||
	response, err := s.httpClient.R().
 | 
			
		||||
		SetHeader("Authorization", s.config.ApiKey).
 | 
			
		||||
		SetHeader("Authorization", apiKey.Value).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -245,3 +252,54 @@ func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
 | 
			
		||||
 | 
			
		||||
	return nil, &res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) PushTask(task types.SdTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message service.NotifyMessage
 | 
			
		||||
			err := s.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := s.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if client == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (s *Service) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.SdJob
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 5 分钟还没完成的任务标记为失败
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
 | 
			
		||||
					job.Progress = service.FailTaskProgress
 | 
			
		||||
					job.ErrMsg = "任务超时"
 | 
			
		||||
					s.db.Updates(&job)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,24 +0,0 @@
 | 
			
		||||
package sd
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import logger2 "geekai/logger"
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type NotifyMessage struct {
 | 
			
		||||
	UserId  int    `json:"user_id"`
 | 
			
		||||
	JobId   int    `json:"job_id"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	Running  = "RUNNING"
 | 
			
		||||
	Finished = "FINISH"
 | 
			
		||||
	Failed   = "FAIL"
 | 
			
		||||
)
 | 
			
		||||
@@ -28,8 +28,8 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *SmtpService) SendVerifyCode(to string, code int) error {
 | 
			
		||||
	subject := "Geek-AI 注册验证码"
 | 
			
		||||
	body := fmt.Sprintf("您正在注册 Geek-AI 助手账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", code)
 | 
			
		||||
	subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
 | 
			
		||||
	body := fmt.Sprintf("您正在注册 %s 账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code)
 | 
			
		||||
 | 
			
		||||
	auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host)
 | 
			
		||||
	if s.config.UseTls {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										456
									
								
								api/service/suno/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										456
									
								
								api/service/suno/service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,456 @@
 | 
			
		||||
package suno
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"io"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	httpClient    *req.Client
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	notifyQueue   *store.RedisQueue
 | 
			
		||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
			
		||||
		db:            db,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("Suno_Task_Queue", redisCli),
 | 
			
		||||
		notifyQueue:   store.NewRedisQueue("Suno_Notify_Queue", redisCli),
 | 
			
		||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) PushTask(task types.SunoTask) {
 | 
			
		||||
	logger.Infof("add a new Suno task to the task list: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	// 将数据库中未提交的人物加载到队列
 | 
			
		||||
	var jobs []model.SunoJob
 | 
			
		||||
	s.db.Where("task_id", "").Find(&jobs)
 | 
			
		||||
	for _, v := range jobs {
 | 
			
		||||
		s.PushTask(types.SunoTask{
 | 
			
		||||
			Id:           v.Id,
 | 
			
		||||
			Channel:      v.Channel,
 | 
			
		||||
			UserId:       v.UserId,
 | 
			
		||||
			Type:         v.Type,
 | 
			
		||||
			Title:        v.Title,
 | 
			
		||||
			RefTaskId:    v.RefTaskId,
 | 
			
		||||
			RefSongId:    v.RefSongId,
 | 
			
		||||
			Prompt:       v.Prompt,
 | 
			
		||||
			Tags:         v.Tags,
 | 
			
		||||
			Model:        v.ModelName,
 | 
			
		||||
			Instrumental: v.Instrumental,
 | 
			
		||||
			ExtendSecs:   v.ExtendSecs,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("Starting Suno job consumer...")
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var task types.SunoTask
 | 
			
		||||
			err := s.taskQueue.LPop(&task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			var r RespVo
 | 
			
		||||
			if task.Type == 3 && task.SongId != "" { // 歌曲拼接
 | 
			
		||||
				r, err = s.Merge(task)
 | 
			
		||||
			} else if task.Type == 4 && task.AudioURL != "" { // 上传歌曲
 | 
			
		||||
				r, err = s.Upload(task)
 | 
			
		||||
			} else { // 歌曲创作
 | 
			
		||||
				r, err = s.Create(task)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("create task with error: %v", err)
 | 
			
		||||
				s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"err_msg":  err.Error(),
 | 
			
		||||
					"progress": service.FailTaskProgress,
 | 
			
		||||
				})
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新任务信息
 | 
			
		||||
			s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
				"task_id": r.Data,
 | 
			
		||||
				"channel": r.Channel,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RespVo struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
	Data    string `json:"data"`
 | 
			
		||||
	Channel string `json:"channel,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
 | 
			
		||||
	if task.Channel != "" {
 | 
			
		||||
		session = session.Where("api_url", task.Channel)
 | 
			
		||||
	}
 | 
			
		||||
	tx := session.Order("last_used_at DESC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return RespVo{}, errors.New("no available API KEY for Suno")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reqBody := map[string]interface{}{
 | 
			
		||||
		"task_id":           task.RefTaskId,
 | 
			
		||||
		"continue_clip_id":  task.RefSongId,
 | 
			
		||||
		"continue_at":       task.ExtendSecs,
 | 
			
		||||
		"make_instrumental": task.Instrumental,
 | 
			
		||||
	}
 | 
			
		||||
	// 灵感模式
 | 
			
		||||
	if task.Type == 1 {
 | 
			
		||||
		reqBody["gpt_description_prompt"] = task.Prompt
 | 
			
		||||
	} else { // 自定义模式
 | 
			
		||||
		reqBody["prompt"] = task.Prompt
 | 
			
		||||
		reqBody["tags"] = task.Tags
 | 
			
		||||
		reqBody["mv"] = task.Model
 | 
			
		||||
		reqBody["title"] = task.Title
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res RespVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
 | 
			
		||||
	logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(reqBody).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != "success" {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
 | 
			
		||||
	}
 | 
			
		||||
	// update the last_use_at for api key
 | 
			
		||||
	apiKey.LastUsedAt = time.Now().Unix()
 | 
			
		||||
	session.Updates(&apiKey)
 | 
			
		||||
	res.Channel = apiKey.ApiURL
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Merge(task types.SunoTask) (RespVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
 | 
			
		||||
	if task.Channel != "" {
 | 
			
		||||
		session = session.Where("api_url", task.Channel)
 | 
			
		||||
	}
 | 
			
		||||
	tx := session.Order("last_used_at DESC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return RespVo{}, errors.New("no available API KEY for Suno")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reqBody := map[string]interface{}{
 | 
			
		||||
		"clip_id":   task.SongId,
 | 
			
		||||
		"is_infill": false,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res RespVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/suno/submit/concat", apiKey.ApiURL)
 | 
			
		||||
	logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(reqBody).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != "success" {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
 | 
			
		||||
	}
 | 
			
		||||
	// update the last_use_at for api key
 | 
			
		||||
	apiKey.LastUsedAt = time.Now().Unix()
 | 
			
		||||
	session.Updates(&apiKey)
 | 
			
		||||
	res.Channel = apiKey.ApiURL
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
 | 
			
		||||
	if task.Channel != "" {
 | 
			
		||||
		session = session.Where("api_url", task.Channel)
 | 
			
		||||
	}
 | 
			
		||||
	tx := session.Order("last_used_at DESC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return RespVo{}, errors.New("no available API KEY for Suno")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reqBody := map[string]interface{}{
 | 
			
		||||
		"url": task.AudioURL,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res RespVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/suno/uploads/audio-url", apiKey.ApiURL)
 | 
			
		||||
	logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(reqBody).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.StatusCode != 200 {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != "success" {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
 | 
			
		||||
	}
 | 
			
		||||
	// update the last_use_at for api key
 | 
			
		||||
	apiKey.LastUsedAt = time.Now().Unix()
 | 
			
		||||
	session.Updates(&apiKey)
 | 
			
		||||
	res.Channel = apiKey.ApiURL
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Suno task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message service.NotifyMessage
 | 
			
		||||
			err := s.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := s.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if client == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) DownloadFiles() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.SunoJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("progress", 102).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				// 下载图片和音频
 | 
			
		||||
				logger.Infof("try download cover image: %s", v.CoverURL)
 | 
			
		||||
				coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("download image with error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Infof("try download audio: %s", v.AudioURL)
 | 
			
		||||
				audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("download audio with error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				v.CoverURL = coverURL
 | 
			
		||||
				v.AudioURL = audioURL
 | 
			
		||||
				v.Progress = 100
 | 
			
		||||
				s.db.Updates(&v)
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SyncTaskProgress 异步拉取任务
 | 
			
		||||
func (s *Service) SyncTaskProgress() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var jobs []model.SunoJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				task, err := s.QueryTask(job.TaskId, job.Channel)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("query task with error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if task.Code != "success" {
 | 
			
		||||
					logger.Errorf("query task with error: %v", task.Message)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Debugf("task: %+v", task.Data.Status)
 | 
			
		||||
				// 任务完成,删除旧任务插入两条新任务
 | 
			
		||||
				if task.Data.Status == "SUCCESS" {
 | 
			
		||||
					var jobId = job.Id
 | 
			
		||||
					var flag = false
 | 
			
		||||
					tx := s.db.Begin()
 | 
			
		||||
					for _, v := range task.Data.Data {
 | 
			
		||||
						job.Id = 0
 | 
			
		||||
						job.Progress = 102 // 102 表示资源未下载完成
 | 
			
		||||
						job.Title = v.Title
 | 
			
		||||
						job.SongId = v.Id
 | 
			
		||||
						job.Duration = int(v.Metadata.Duration)
 | 
			
		||||
						job.Prompt = v.Metadata.Prompt
 | 
			
		||||
						job.Tags = v.Metadata.Tags
 | 
			
		||||
						job.ModelName = v.ModelName
 | 
			
		||||
						job.RawData = utils.JsonEncode(v)
 | 
			
		||||
						job.CoverURL = v.ImageLargeUrl
 | 
			
		||||
						job.AudioURL = v.AudioUrl
 | 
			
		||||
 | 
			
		||||
						if err = tx.Create(&job).Error; err != nil {
 | 
			
		||||
							logger.Error("create job with error: %v", err)
 | 
			
		||||
							tx.Rollback()
 | 
			
		||||
							break
 | 
			
		||||
						}
 | 
			
		||||
						flag = true
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// 删除旧任务
 | 
			
		||||
					if flag {
 | 
			
		||||
						if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
 | 
			
		||||
							logger.Error("create job with error: %v", err)
 | 
			
		||||
							tx.Rollback()
 | 
			
		||||
							continue
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					tx.Commit()
 | 
			
		||||
 | 
			
		||||
				} else if task.Data.FailReason != "" {
 | 
			
		||||
					job.Progress = service.FailTaskProgress
 | 
			
		||||
					job.ErrMsg = task.Data.FailReason
 | 
			
		||||
					s.db.Updates(&job)
 | 
			
		||||
					s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type QueryRespVo struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
	Data    struct {
 | 
			
		||||
		TaskId     string `json:"task_id"`
 | 
			
		||||
		Action     string `json:"action"`
 | 
			
		||||
		Status     string `json:"status"`
 | 
			
		||||
		FailReason string `json:"fail_reason"`
 | 
			
		||||
		SubmitTime int    `json:"submit_time"`
 | 
			
		||||
		StartTime  int    `json:"start_time"`
 | 
			
		||||
		FinishTime int    `json:"finish_time"`
 | 
			
		||||
		Progress   string `json:"progress"`
 | 
			
		||||
		Data       []struct {
 | 
			
		||||
			Id       string `json:"id"`
 | 
			
		||||
			Title    string `json:"title"`
 | 
			
		||||
			Status   string `json:"status"`
 | 
			
		||||
			Metadata struct {
 | 
			
		||||
				Tags         string      `json:"tags"`
 | 
			
		||||
				Type         string      `json:"type"`
 | 
			
		||||
				Prompt       string      `json:"prompt"`
 | 
			
		||||
				Stream       bool        `json:"stream"`
 | 
			
		||||
				Duration     float64     `json:"duration"`
 | 
			
		||||
				ErrorMessage interface{} `json:"error_message"`
 | 
			
		||||
			} `json:"metadata"`
 | 
			
		||||
			AudioUrl          string `json:"audio_url"`
 | 
			
		||||
			ImageUrl          string `json:"image_url"`
 | 
			
		||||
			VideoUrl          string `json:"video_url"`
 | 
			
		||||
			ModelName         string `json:"model_name"`
 | 
			
		||||
			DisplayName       string `json:"display_name"`
 | 
			
		||||
			ImageLargeUrl     string `json:"image_large_url"`
 | 
			
		||||
			MajorModelVersion string `json:"major_model_version"`
 | 
			
		||||
		} `json:"data"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	err := s.db.Session(&gorm.Session{}).Where("type", "suno").
 | 
			
		||||
		Where("api_url", channel).
 | 
			
		||||
		Where("enabled", true).
 | 
			
		||||
		Order("last_used_at DESC").First(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRespVo{}, errors.New("no available API KEY for Suno")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
 | 
			
		||||
	var res QueryRespVo
 | 
			
		||||
	r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer r.Body.Close()
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,4 +1,17 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
const FailTaskProgress = 101
 | 
			
		||||
const (
 | 
			
		||||
	TaskStatusRunning  = "RUNNING"
 | 
			
		||||
	TaskStatusFinished = "FINISH"
 | 
			
		||||
	TaskStatusFailed   = "FAIL"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type NotifyMessage struct {
 | 
			
		||||
	UserId  int    `json:"user_id"`
 | 
			
		||||
	JobId   int    `json:"job_id"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
 | 
			
		||||
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										83
									
								
								api/service/user_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								api/service/user_service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,83 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UserService struct {
 | 
			
		||||
	db   *gorm.DB
 | 
			
		||||
	lock sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUserService(db *gorm.DB) *UserService {
 | 
			
		||||
	return &UserService{db: db, lock: sync.Mutex{}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IncreasePower 增加用户算力
 | 
			
		||||
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
 | 
			
		||||
	s.lock.Lock()
 | 
			
		||||
	defer s.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	tx := s.db.Begin()
 | 
			
		||||
	err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	tx.Where("id", userId).First(&user)
 | 
			
		||||
	err = tx.Create(&model.PowerLog{
 | 
			
		||||
		UserId:    user.Id,
 | 
			
		||||
		Username:  user.Username,
 | 
			
		||||
		Type:      log.Type,
 | 
			
		||||
		Amount:    power,
 | 
			
		||||
		Balance:   user.Power,
 | 
			
		||||
		Mark:      types.PowerAdd,
 | 
			
		||||
		Model:     log.Model,
 | 
			
		||||
		Remark:    log.Remark,
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DecreasePower 减少用户算力
 | 
			
		||||
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
 | 
			
		||||
	s.lock.Lock()
 | 
			
		||||
	defer s.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	tx := s.db.Begin()
 | 
			
		||||
	err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		return fmt.Errorf("扣减算力失败:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	tx.Where("id", userId).First(&user)
 | 
			
		||||
	err = tx.Create(&model.PowerLog{
 | 
			
		||||
		UserId:    user.Id,
 | 
			
		||||
		Username:  user.Username,
 | 
			
		||||
		Type:      log.Type,
 | 
			
		||||
		Amount:    power,
 | 
			
		||||
		Balance:   user.Power,
 | 
			
		||||
		Mark:      types.PowerSub,
 | 
			
		||||
		Model:     log.Model,
 | 
			
		||||
		Remark:    log.Remark,
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		return fmt.Errorf("记录算力日志失败:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										330
									
								
								api/service/video/luma.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										330
									
								
								api/service/video/luma.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,330 @@
 | 
			
		||||
package video
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"io"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	httpClient    *req.Client
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	notifyQueue   *store.RedisQueue
 | 
			
		||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
			
		||||
		db:            db,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("Video_Task_Queue", redisCli),
 | 
			
		||||
		notifyQueue:   store.NewRedisQueue("Video_Notify_Queue", redisCli),
 | 
			
		||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) PushTask(task types.VideoTask) {
 | 
			
		||||
	logger.Infof("add a new Video task to the task list: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	// 将数据库中未提交的人物加载到队列
 | 
			
		||||
	var jobs []model.VideoJob
 | 
			
		||||
	s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
 | 
			
		||||
	for _, v := range jobs {
 | 
			
		||||
		var params types.VideoParams
 | 
			
		||||
		if err := utils.JsonDecode(v.Params, ¶ms); err != nil {
 | 
			
		||||
			logger.Errorf("unmarshal params failed: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		s.PushTask(types.VideoTask{
 | 
			
		||||
			Id:      v.Id,
 | 
			
		||||
			Channel: v.Channel,
 | 
			
		||||
			UserId:  v.UserId,
 | 
			
		||||
			Type:    v.Type,
 | 
			
		||||
			TaskId:  v.TaskId,
 | 
			
		||||
			Prompt:  v.Prompt,
 | 
			
		||||
			Params:  params,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("Starting Video job consumer...")
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var task types.VideoTask
 | 
			
		||||
			err := s.taskQueue.LPop(&task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			var r LumaRespVo
 | 
			
		||||
			r, err = s.LumaCreate(task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("create task with error: %v", err)
 | 
			
		||||
				err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"err_msg":   err.Error(),
 | 
			
		||||
					"progress":  service.FailTaskProgress,
 | 
			
		||||
					"cover_url": "/images/failed.jpg",
 | 
			
		||||
				}).Error
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("update task with error: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新任务信息
 | 
			
		||||
			err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
				"task_id":    r.Id,
 | 
			
		||||
				"channel":    r.Channel,
 | 
			
		||||
				"prompt_ext": r.Prompt,
 | 
			
		||||
			}).Error
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("update task with error: %v", err)
 | 
			
		||||
				s.PushTask(task)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type LumaRespVo struct {
 | 
			
		||||
	Id                  string      `json:"id"`
 | 
			
		||||
	Prompt              string      `json:"prompt"`
 | 
			
		||||
	State               string      `json:"state"`
 | 
			
		||||
	CreatedAt           time.Time   `json:"created_at"`
 | 
			
		||||
	Video               interface{} `json:"video"`
 | 
			
		||||
	Liked               interface{} `json:"liked"`
 | 
			
		||||
	EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
 | 
			
		||||
	Channel             string      `json:"channel,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
 | 
			
		||||
	if task.Channel != "" {
 | 
			
		||||
		session = session.Where("api_url", task.Channel)
 | 
			
		||||
	}
 | 
			
		||||
	tx := session.Order("last_used_at DESC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return LumaRespVo{}, errors.New("no available API KEY for Luma")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reqBody := map[string]interface{}{
 | 
			
		||||
		"user_prompt":   task.Prompt,
 | 
			
		||||
		"expand_prompt": task.Params.PromptOptimize,
 | 
			
		||||
		"loop":          task.Params.Loop,
 | 
			
		||||
		"image_url":     task.Params.StartImgURL,
 | 
			
		||||
		"image_end_url": task.Params.EndImgURL,
 | 
			
		||||
	}
 | 
			
		||||
	var res LumaRespVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
 | 
			
		||||
	logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(reqBody).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.StatusCode != 200 && r.StatusCode != 201 {
 | 
			
		||||
		return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update the last_use_at for api key
 | 
			
		||||
	apiKey.LastUsedAt = time.Now().Unix()
 | 
			
		||||
	session.Updates(&apiKey)
 | 
			
		||||
	res.Channel = apiKey.ApiURL
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Suno task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message service.NotifyMessage
 | 
			
		||||
			err := s.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := s.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if client == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) DownloadFiles() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.VideoJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("progress", 102).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				if v.WaterURL == "" {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Infof("try download video: %s", v.WaterURL)
 | 
			
		||||
				videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("download video with error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				logger.Infof("download video success: %s", videoURL)
 | 
			
		||||
				v.WaterURL = videoURL
 | 
			
		||||
 | 
			
		||||
				if v.VideoURL != "" {
 | 
			
		||||
					logger.Infof("try download no water video: %s", v.VideoURL)
 | 
			
		||||
					videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logger.Errorf("download video with error: %v", err)
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				logger.Info("download no water video success: %s", videoURL)
 | 
			
		||||
				v.VideoURL = videoURL
 | 
			
		||||
				v.Progress = 100
 | 
			
		||||
				s.db.Updates(&v)
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SyncTaskProgress 异步拉取任务
 | 
			
		||||
func (s *Service) SyncTaskProgress() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var jobs []model.VideoJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				task, err := s.QueryLumaTask(job.TaskId, job.Channel)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("query task with error: %v", err)
 | 
			
		||||
					// 更新任务信息
 | 
			
		||||
					s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
						"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
 | 
			
		||||
						"err_msg":  err.Error(),
 | 
			
		||||
					})
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Debugf("task: %+v", task)
 | 
			
		||||
				if task.State == "completed" { // 更新任务信息
 | 
			
		||||
					data := map[string]interface{}{
 | 
			
		||||
						"progress":   102, // 102 表示资源未下载完成,
 | 
			
		||||
						"water_url":  task.Video.Url,
 | 
			
		||||
						"raw_data":   utils.JsonEncode(task),
 | 
			
		||||
						"prompt_ext": task.Prompt,
 | 
			
		||||
					}
 | 
			
		||||
					if task.Video.DownloadUrl != "" {
 | 
			
		||||
						data["video_url"] = task.Video.DownloadUrl
 | 
			
		||||
					}
 | 
			
		||||
					err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logger.Errorf("更新数据库失败:%v", err)
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type LumaTaskVo struct {
 | 
			
		||||
	Id    string      `json:"id"`
 | 
			
		||||
	Liked interface{} `json:"liked"`
 | 
			
		||||
	State string      `json:"state"`
 | 
			
		||||
	Video struct {
 | 
			
		||||
		Url         string `json:"url"`
 | 
			
		||||
		Width       int    `json:"width"`
 | 
			
		||||
		Height      int    `json:"height"`
 | 
			
		||||
		DownloadUrl string `json:"download_url"`
 | 
			
		||||
	} `json:"video"`
 | 
			
		||||
	Prompt              string      `json:"prompt"`
 | 
			
		||||
	CreatedAt           time.Time   `json:"created_at"`
 | 
			
		||||
	EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	err := s.db.Session(&gorm.Session{}).Where("type", "luma").
 | 
			
		||||
		Where("api_url", channel).
 | 
			
		||||
		Where("enabled", true).
 | 
			
		||||
		Order("last_used_at DESC").First(&apiKey).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return LumaTaskVo{}, errors.New("no available API KEY for Luma")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
 | 
			
		||||
	var res LumaTaskVo
 | 
			
		||||
	r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer r.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if r.StatusCode != 200 {
 | 
			
		||||
		return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,101 +0,0 @@
 | 
			
		||||
package wx
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"github.com/eatmoreapple/openwechat"
 | 
			
		||||
	"github.com/skip2/go-qrcode"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 微信收款机器人
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Bot struct {
 | 
			
		||||
	bot   *openwechat.Bot
 | 
			
		||||
	token string
 | 
			
		||||
	db    *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWeChatBot(db *gorm.DB) *Bot {
 | 
			
		||||
	bot := openwechat.DefaultBot(openwechat.Desktop)
 | 
			
		||||
	return &Bot{
 | 
			
		||||
		bot: bot,
 | 
			
		||||
		db:  db,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) Run() error {
 | 
			
		||||
	logger.Info("Starting WeChat Bot...")
 | 
			
		||||
 | 
			
		||||
	// set message handler
 | 
			
		||||
	b.bot.MessageHandler = func(msg *openwechat.Message) {
 | 
			
		||||
		b.messageHandler(msg)
 | 
			
		||||
	}
 | 
			
		||||
	// scan code login callback
 | 
			
		||||
	b.bot.UUIDCallback = b.qrCodeCallBack
 | 
			
		||||
	debug, err := strconv.ParseBool(os.Getenv("APP_DEBUG"))
 | 
			
		||||
	if debug {
 | 
			
		||||
		reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json")
 | 
			
		||||
		err = b.bot.HotLogin(reloadStorage, true)
 | 
			
		||||
	} else {
 | 
			
		||||
		err = b.bot.Login()
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Info("微信登录成功!")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// message handler
 | 
			
		||||
func (b *Bot) messageHandler(msg *openwechat.Message) {
 | 
			
		||||
	sender, err := msg.Sender()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 只处理微信支付的推送消息
 | 
			
		||||
	if sender.NickName == "微信支付" ||
 | 
			
		||||
		msg.MsgType == openwechat.MsgTypeApp ||
 | 
			
		||||
		msg.AppMsgType == openwechat.AppMsgTypeUrl {
 | 
			
		||||
		// 解析支付金额
 | 
			
		||||
		message := parseTransactionMessage(msg.Content)
 | 
			
		||||
		transaction := extractTransaction(message)
 | 
			
		||||
		logger.Infof("解析到收款信息:%+v", transaction)
 | 
			
		||||
		if transaction.TransId != "" {
 | 
			
		||||
			var item model.Reward
 | 
			
		||||
			res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
 | 
			
		||||
			if item.Id > 0 {
 | 
			
		||||
				logger.Error("当前交易 ID 己经存在!")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			res = b.db.Create(&model.Reward{
 | 
			
		||||
				TxId:   transaction.TransId,
 | 
			
		||||
				Amount: transaction.Amount,
 | 
			
		||||
				Remark: transaction.Remark,
 | 
			
		||||
				Status: false,
 | 
			
		||||
			})
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Errorf("交易保存失败: %v", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) qrCodeCallBack(uuid string) {
 | 
			
		||||
	logger.Info("请使用微信扫描下面二维码登录")
 | 
			
		||||
	q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
 | 
			
		||||
	logger.Info(q.ToString(true))
 | 
			
		||||
}
 | 
			
		||||
@@ -1,112 +0,0 @@
 | 
			
		||||
package wx
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/xml"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Message 转账消息
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Des string
 | 
			
		||||
	Url string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Transaction 解析后的交易信息
 | 
			
		||||
type Transaction struct {
 | 
			
		||||
	TransId string  `json:"trans_id"` // 微信转账交易 ID
 | 
			
		||||
	Amount  float64 `json:"amount"`   // 微信转账交易金额
 | 
			
		||||
	Remark  string  `json:"remark"`   // 转账备注
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 解析微信转账消息
 | 
			
		||||
func parseTransactionMessage(xmlData string) *Message {
 | 
			
		||||
	decoder := xml.NewDecoder(strings.NewReader(xmlData))
 | 
			
		||||
	message := Message{}
 | 
			
		||||
	for {
 | 
			
		||||
		token, err := decoder.Token()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		switch se := token.(type) {
 | 
			
		||||
		case xml.StartElement:
 | 
			
		||||
			var value string
 | 
			
		||||
			if se.Name.Local == "des" && message.Des == "" {
 | 
			
		||||
				if err := decoder.DecodeElement(&value, &se); err == nil {
 | 
			
		||||
					message.Des = strings.TrimSpace(value)
 | 
			
		||||
				}
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			if se.Name.Local == "weapp_path" || se.Name.Local == "url" {
 | 
			
		||||
				if err := decoder.DecodeElement(&value, &se); err == nil {
 | 
			
		||||
					if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") {
 | 
			
		||||
						message.Url = value
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 兼容旧版消息记录
 | 
			
		||||
	if message.Url == "" {
 | 
			
		||||
		var msg struct {
 | 
			
		||||
			XMLName xml.Name `xml:"msg"`
 | 
			
		||||
			AppMsg  struct {
 | 
			
		||||
				Des string `xml:"des"`
 | 
			
		||||
				Url string `xml:"url"`
 | 
			
		||||
			} `xml:"appmsg"`
 | 
			
		||||
		}
 | 
			
		||||
		if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil {
 | 
			
		||||
			message.Url = msg.AppMsg.Url
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &message
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 导出交易信息
 | 
			
		||||
func extractTransaction(message *Message) Transaction {
 | 
			
		||||
	var tx = Transaction{}
 | 
			
		||||
	// 导出交易金额和备注
 | 
			
		||||
	lines := strings.Split(message.Des, "\n")
 | 
			
		||||
	for _, line := range lines {
 | 
			
		||||
		line = strings.TrimSpace(line)
 | 
			
		||||
		if len(line) == 0 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// 解析收款金额
 | 
			
		||||
		prefix := "收款金额¥"
 | 
			
		||||
		if strings.HasPrefix(line, prefix) {
 | 
			
		||||
			if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
 | 
			
		||||
				tx.Amount = value
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// 解析收款备注
 | 
			
		||||
		prefix = "付款方备注"
 | 
			
		||||
		if strings.HasPrefix(line, prefix) {
 | 
			
		||||
			tx.Remark = line[len(prefix):]
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 解析交易 ID
 | 
			
		||||
	parse, err := url.Parse(message.Url)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		tx.TransId = parse.Query().Get("id")
 | 
			
		||||
		if tx.TransId == "" {
 | 
			
		||||
			tx.TransId = parse.Query().Get("trans_id")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tx
 | 
			
		||||
}
 | 
			
		||||
@@ -81,51 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
 | 
			
		||||
// 自动将 VIP 会员的算力补充到每月赠送的最大值
 | 
			
		||||
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
 | 
			
		||||
	logger.Info("开始进行月底账号盘点...")
 | 
			
		||||
	var users []model.User
 | 
			
		||||
	res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return "No vip users found"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var sysConfig model.Config
 | 
			
		||||
	res = e.db.Where("marker", "system").First(&sysConfig)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return "error with get system config: " + res.Error.Error()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var config types.SystemConfig
 | 
			
		||||
	err := utils.JsonDecode(sysConfig.Config, &config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "error with decode system config: " + err.Error()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, u := range users {
 | 
			
		||||
		// 处理过期的 VIP
 | 
			
		||||
		if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
 | 
			
		||||
			u.Vip = false
 | 
			
		||||
			e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// update user
 | 
			
		||||
		tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", config.VipMonthPower))
 | 
			
		||||
		// 记录算力变动日志
 | 
			
		||||
		if tx.Error == nil {
 | 
			
		||||
			var user model.User
 | 
			
		||||
			e.db.Where("id", u.Id).First(&user)
 | 
			
		||||
			e.db.Create(&model.PowerLog{
 | 
			
		||||
				UserId:    u.Id,
 | 
			
		||||
				Username:  u.Username,
 | 
			
		||||
				Type:      types.PowerRecharge,
 | 
			
		||||
				Amount:    config.VipMonthPower,
 | 
			
		||||
				Mark:      types.PowerAdd,
 | 
			
		||||
				Balance:   user.Power,
 | 
			
		||||
				Model:     "系统盘点",
 | 
			
		||||
				Remark:    fmt.Sprintf("VIP会员每月算力派发,:%d", config.VipMonthPower),
 | 
			
		||||
				CreatedAt: time.Now(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("月底盘点完成!")
 | 
			
		||||
	return "success"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -29,15 +29,9 @@ func NewLevelDB() (*LevelDB, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *LevelDB) Put(key string, value interface{}) error {
 | 
			
		||||
	var byteData []byte
 | 
			
		||||
	if v, ok := value.(string); ok {
 | 
			
		||||
		byteData = []byte(v)
 | 
			
		||||
	} else {
 | 
			
		||||
		b, err := json.Marshal(value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		byteData = b
 | 
			
		||||
	byteData, err := json.Marshal(value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return db.driver.Put([]byte(key), byteData, nil)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,6 @@ package model
 | 
			
		||||
// ApiKey OpenAI API 模型
 | 
			
		||||
type ApiKey struct {
 | 
			
		||||
	BaseModel
 | 
			
		||||
	Platform   string
 | 
			
		||||
	Name       string
 | 
			
		||||
	Type       string // 用途 chat => 聊天,img => 绘图
 | 
			
		||||
	Value      string // API Key 的值
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package model
 | 
			
		||||
 | 
			
		||||
type ChatModel struct {
 | 
			
		||||
	BaseModel
 | 
			
		||||
	Platform    string
 | 
			
		||||
	Name        string
 | 
			
		||||
	Value       string // API Key 的值
 | 
			
		||||
	SortNum     int
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										16
									
								
								api/store/model/redeem.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								api/store/model/redeem.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
// 兑换码
 | 
			
		||||
 | 
			
		||||
type Redeem struct {
 | 
			
		||||
	Id         uint   `gorm:"primarykey;column:id"`
 | 
			
		||||
	UserId     uint   // 用户 ID
 | 
			
		||||
	Name       string // 名称
 | 
			
		||||
	Power      int    // 算力
 | 
			
		||||
	Code       string // 兑换码
 | 
			
		||||
	Enabled    bool   // 启用状态
 | 
			
		||||
	RedeemedAt int64  // 兑换时间
 | 
			
		||||
	CreatedAt  time.Time
 | 
			
		||||
}
 | 
			
		||||
@@ -1,13 +0,0 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
// 用户打赏
 | 
			
		||||
 | 
			
		||||
type Reward struct {
 | 
			
		||||
	BaseModel
 | 
			
		||||
	UserId   uint    // 用户 ID
 | 
			
		||||
	TxId     string  // 交易ID
 | 
			
		||||
	Amount   float64 // 打赏金额
 | 
			
		||||
	Remark   string  // 打赏备注
 | 
			
		||||
	Status   bool    // 核销状态
 | 
			
		||||
	Exchange string  // 众筹兑换详情,JSON
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										34
									
								
								api/store/model/suno_job.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								api/store/model/suno_job.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
type SunoJob struct {
 | 
			
		||||
	Id           uint `gorm:"primarykey;column:id"`
 | 
			
		||||
	UserId       int
 | 
			
		||||
	Channel      string // 频道
 | 
			
		||||
	Title        string
 | 
			
		||||
	Type         int
 | 
			
		||||
	TaskId       string
 | 
			
		||||
	RefTaskId    string // 续写的任务id
 | 
			
		||||
	Tags         string // 歌曲风格和标签
 | 
			
		||||
	Instrumental bool   // 是否生成纯音乐
 | 
			
		||||
	ExtendSecs   int    // 续写秒数
 | 
			
		||||
	SongId       string // 续写的歌曲id
 | 
			
		||||
	RefSongId    string
 | 
			
		||||
	Prompt       string // 提示词
 | 
			
		||||
	CoverURL     string // 封面图 URL
 | 
			
		||||
	AudioURL     string // 音频 URL
 | 
			
		||||
	ModelName    string // 模型名称
 | 
			
		||||
	Progress     int    // 任务进度
 | 
			
		||||
	Duration     int    // 银屏时长,秒
 | 
			
		||||
	Publish      bool   // 是否发布
 | 
			
		||||
	ErrMsg       string // 错误信息
 | 
			
		||||
	RawData      string // 原始数据 json
 | 
			
		||||
	Power        int    // 消耗算力
 | 
			
		||||
	PlayTimes    int    // 播放次数
 | 
			
		||||
	CreatedAt    time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (SunoJob) TableName() string {
 | 
			
		||||
	return "chatgpt_suno_jobs"
 | 
			
		||||
}
 | 
			
		||||
@@ -4,6 +4,8 @@ type User struct {
 | 
			
		||||
	BaseModel
 | 
			
		||||
	Username    string
 | 
			
		||||
	Nickname    string
 | 
			
		||||
	Email       string
 | 
			
		||||
	Mobile      string
 | 
			
		||||
	Password    string
 | 
			
		||||
	Avatar      string
 | 
			
		||||
	Salt        string // 密码盐
 | 
			
		||||
@@ -15,5 +17,7 @@ type User struct {
 | 
			
		||||
	Status      bool   `gorm:"default:true"` // 当前状态
 | 
			
		||||
	LastLoginAt int64  // 最后登录时间
 | 
			
		||||
	LastLoginIp string // 最后登录 IP
 | 
			
		||||
	OpenId      string `gorm:"column:openid"`
 | 
			
		||||
	Platform    string `json:"platform"`
 | 
			
		||||
	Vip         bool   // 是否 VIP 会员
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										27
									
								
								api/store/model/video_job.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								api/store/model/video_job.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
type VideoJob struct {
 | 
			
		||||
	Id        uint `gorm:"primarykey;column:id"`
 | 
			
		||||
	UserId    int
 | 
			
		||||
	Channel   string // 频道
 | 
			
		||||
	Type      string // luma,runway,cog
 | 
			
		||||
	TaskId    string
 | 
			
		||||
	Prompt    string // 提示词
 | 
			
		||||
	PromptExt string // 优化后提示词
 | 
			
		||||
	CoverURL  string // 封面图 URL
 | 
			
		||||
	VideoURL  string // 无水印视频 URL
 | 
			
		||||
	WaterURL  string // 有水印视频 URL
 | 
			
		||||
	Progress  int    // 任务进度
 | 
			
		||||
	Publish   bool   // 是否发布
 | 
			
		||||
	ErrMsg    string // 错误信息
 | 
			
		||||
	RawData   string // 原始数据 json
 | 
			
		||||
	Power     int    // 消耗算力
 | 
			
		||||
	Params    string // 任务参数
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (VideoJob) TableName() string {
 | 
			
		||||
	return "chatgpt_video_jobs"
 | 
			
		||||
}
 | 
			
		||||
@@ -3,7 +3,6 @@ package vo
 | 
			
		||||
// ApiKey OpenAI API 模型
 | 
			
		||||
type ApiKey struct {
 | 
			
		||||
	BaseVo
 | 
			
		||||
	Platform   string `json:"platform"`
 | 
			
		||||
	Name       string `json:"name"`
 | 
			
		||||
	Type       string `json:"type"`
 | 
			
		||||
	Value      string `json:"value"` // API Key 的值
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package vo
 | 
			
		||||
 | 
			
		||||
type ChatModel struct {
 | 
			
		||||
	BaseVo
 | 
			
		||||
	Platform    string  `json:"platform"`
 | 
			
		||||
	Name        string  `json:"name"`
 | 
			
		||||
	Value       string  `json:"value"`
 | 
			
		||||
	Enabled     bool    `json:"enabled"`
 | 
			
		||||
@@ -12,6 +11,6 @@ type ChatModel struct {
 | 
			
		||||
	MaxTokens   int     `json:"max_tokens"`  // 最大响应长度
 | 
			
		||||
	MaxContext  int     `json:"max_context"` // 最大上下文长度
 | 
			
		||||
	Temperature float32 `json:"temperature"` // 模型温度
 | 
			
		||||
	KeyId       int     `json:"key_id"`
 | 
			
		||||
	KeyId       int     `json:"key_id,omitempty"`
 | 
			
		||||
	KeyName     string  `json:"key_name"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										13
									
								
								api/store/vo/redeem.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								api/store/vo/redeem.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
			
		||||
package vo
 | 
			
		||||
 | 
			
		||||
type Redeem struct {
 | 
			
		||||
	Id         uint   `json:"id"`
 | 
			
		||||
	UserId     uint   `json:"user_id"` // 用户 ID
 | 
			
		||||
	Name       string `json:"name"`
 | 
			
		||||
	Username   string `json:"username"`
 | 
			
		||||
	Power      int    `json:"power"` // 算力
 | 
			
		||||
	Code       string `json:"code"`  // 兑换码
 | 
			
		||||
	Enabled    bool   `json:"enabled"`
 | 
			
		||||
	RedeemedAt int64  `json:"redeemed_at"` // 兑换时间
 | 
			
		||||
	CreatedAt  int64  `json:"created_at"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,16 +0,0 @@
 | 
			
		||||
package vo
 | 
			
		||||
 | 
			
		||||
type Reward struct {
 | 
			
		||||
	BaseVo
 | 
			
		||||
	UserId   uint           `json:"user_id"` // 用户 ID
 | 
			
		||||
	Username string         `json:"username"`
 | 
			
		||||
	TxId     string         `json:"tx_id"`  // 交易ID
 | 
			
		||||
	Amount   float64        `json:"amount"` // 打赏金额
 | 
			
		||||
	Remark   string         `json:"remark"` // 打赏备注
 | 
			
		||||
	Status   bool           `json:"status"` // 核销状态
 | 
			
		||||
	Exchange RewardExchange `json:"exchange"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RewardExchange struct {
 | 
			
		||||
	Power int `json:"power"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										30
									
								
								api/store/vo/suno_job.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								api/store/vo/suno_job.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,30 @@
 | 
			
		||||
package vo
 | 
			
		||||
 | 
			
		||||
type SunoJob struct {
 | 
			
		||||
	Id           uint                   `json:"id"`
 | 
			
		||||
	UserId       int                    `json:"user_id"`
 | 
			
		||||
	Channel      string                 `json:"channel"`
 | 
			
		||||
	Title        string                 `json:"title"`
 | 
			
		||||
	Type         int                    `json:"type"`
 | 
			
		||||
	TaskId       string                 `json:"task_id"`
 | 
			
		||||
	RefTaskId    string                 `json:"ref_task_id"`  // 续写的任务id
 | 
			
		||||
	Tags         string                 `json:"tags"`         // 歌曲风格和标签
 | 
			
		||||
	Instrumental bool                   `json:"instrumental"` // 是否生成纯音乐
 | 
			
		||||
	ExtendSecs   int                    `json:"extend_secs"`  // 续写秒数
 | 
			
		||||
	SongId       string                 `json:"song_id"`      // 续写的歌曲id
 | 
			
		||||
	RefSongId    string                 `json:"ref_song_id"`  // 续写的歌曲id
 | 
			
		||||
	Prompt       string                 `json:"prompt"`       // 提示词
 | 
			
		||||
	CoverURL     string                 `json:"cover_url"`    // 封面图 URL
 | 
			
		||||
	AudioURL     string                 `json:"audio_url"`    // 音频 URL
 | 
			
		||||
	ModelName    string                 `json:"model_name"`   // 模型名称
 | 
			
		||||
	Progress     int                    `json:"progress"`     // 任务进度
 | 
			
		||||
	Duration     int                    `json:"duration"`     // 银屏时长,秒
 | 
			
		||||
	Publish      bool                   `json:"publish"`      // 是否发布
 | 
			
		||||
	ErrMsg       string                 `json:"err_msg"`      // 错误信息
 | 
			
		||||
	RawData      map[string]interface{} `json:"raw_data"`     // 原始数据 json
 | 
			
		||||
	Power        int                    `json:"power"`        // 消耗算力
 | 
			
		||||
	RefSong      map[string]interface{} `json:"ref_song,omitempty"`
 | 
			
		||||
	User         map[string]interface{} `json:"user,omitempty"` //关联用户信息
 | 
			
		||||
	PlayTimes    int                    `json:"play_times"`     // 播放次数
 | 
			
		||||
	CreatedAt    int64                  `json:"created_at"`
 | 
			
		||||
}
 | 
			
		||||
@@ -4,6 +4,8 @@ type User struct {
 | 
			
		||||
	BaseVo
 | 
			
		||||
	Username    string   `json:"username"`
 | 
			
		||||
	Nickname    string   `json:"nickname"`
 | 
			
		||||
	Mobile      string   `json:"mobile"`
 | 
			
		||||
	Email       string   `json:"email"`
 | 
			
		||||
	Avatar      string   `json:"avatar"`
 | 
			
		||||
	Salt        string   `json:"salt"`          // 密码盐
 | 
			
		||||
	Power       int      `json:"power"`         // 剩余算力
 | 
			
		||||
@@ -14,4 +16,6 @@ type User struct {
 | 
			
		||||
	LastLoginAt int64    `json:"last_login_at"` // 最后登录时间
 | 
			
		||||
	LastLoginIp string   `json:"last_login_ip"` // 最后登录 IP
 | 
			
		||||
	Vip         bool     `json:"vip"`
 | 
			
		||||
	OpenId      string   `json:"openid"`   // 第三方登录 OpenID
 | 
			
		||||
	Platform    string   `json:"platform"` // 第三方登录平台
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										23
									
								
								api/store/vo/video_job.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								api/store/vo/video_job.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
			
		||||
package vo
 | 
			
		||||
 | 
			
		||||
import "geekai/core/types"
 | 
			
		||||
 | 
			
		||||
type VideoJob struct {
 | 
			
		||||
	Id        uint                   `json:"id"`
 | 
			
		||||
	UserId    int                    `json:"user_id"`
 | 
			
		||||
	Channel   string                 `json:"channel"`
 | 
			
		||||
	Type      string                 `json:"type"`
 | 
			
		||||
	TaskId    string                 `json:"task_id"`
 | 
			
		||||
	Prompt    string                 `json:"prompt"`     // 提示词
 | 
			
		||||
	PromptExt string                 `json:"prompt_ext"` // 提示词
 | 
			
		||||
	CoverURL  string                 `json:"cover_url"`  // 封面图 URL
 | 
			
		||||
	VideoURL  string                 `json:"video_url"`  // 无水印视频 URL
 | 
			
		||||
	WaterURL  string                 `json:"water_url"`  // 有水印视频 URL
 | 
			
		||||
	Progress  int                    `json:"progress"`   // 任务进度
 | 
			
		||||
	Publish   bool                   `json:"publish"`    // 是否发布
 | 
			
		||||
	ErrMsg    string                 `json:"err_msg"`    // 错误信息
 | 
			
		||||
	RawData   map[string]interface{} `json:"raw_data"`   // 原始数据 json
 | 
			
		||||
	Power     int                    `json:"power"`      // 消耗算力
 | 
			
		||||
	Params    types.VideoParams      `json:"params"`     // 任务参数
 | 
			
		||||
	CreatedAt int64                  `json:"created_at"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,12 +1,55 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	text := "https://nk.img.r9it.com/chatgpt-plus/1712709360012445.png"
 | 
			
		||||
	parse, _ := url.Parse(text)
 | 
			
		||||
	fmt.Println(fmt.Sprintf("%s://%s", parse.Scheme, parse.Host))
 | 
			
		||||
const (
 | 
			
		||||
	codeLength = 32 // 兑换码长度
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	codeMap  = make(map[string]bool)
 | 
			
		||||
	mapMutex = &sync.Mutex{}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GenerateUniqueCode 生成唯一兑换码
 | 
			
		||||
func GenerateUniqueCode() (string, error) {
 | 
			
		||||
	for {
 | 
			
		||||
		code, err := generateCode()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		mapMutex.Lock()
 | 
			
		||||
		if !codeMap[code] {
 | 
			
		||||
			codeMap[code] = true
 | 
			
		||||
			mapMutex.Unlock()
 | 
			
		||||
			return code, nil
 | 
			
		||||
		}
 | 
			
		||||
		mapMutex.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// generateCode 生成兑换码
 | 
			
		||||
func generateCode() (string, error) {
 | 
			
		||||
	bytes := make([]byte, codeLength/2) // 因为 hex 编码会使长度翻倍
 | 
			
		||||
	if _, err := rand.Read(bytes); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return hex.EncodeToString(bytes), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	for i := 0; i < 10; i++ {
 | 
			
		||||
		code, err := GenerateUniqueCode()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("Error generating code:", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		fmt.Println("Generated code:", code)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -84,6 +84,8 @@ func CopyObject(src interface{}, dst interface{}) error {
 | 
			
		||||
				case reflect.Bool:
 | 
			
		||||
					value.SetBool(v.Bool())
 | 
			
		||||
					break
 | 
			
		||||
				default:
 | 
			
		||||
					value.Set(v)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										106
									
								
								api/utils/file.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								api/utils/file.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,106 @@
 | 
			
		||||
package utils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/microcosm-cc/bluemonday"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/google/go-tika/tika"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ReadFileContent(filePath string, tikaHost string) (string, error) {
 | 
			
		||||
	// for remote file, download it first
 | 
			
		||||
	if strings.HasPrefix(filePath, "http") {
 | 
			
		||||
		file, err := downloadFile(filePath)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
		filePath = file
 | 
			
		||||
	}
 | 
			
		||||
	// 创建 Tika 客户端
 | 
			
		||||
	client := tika.NewClient(nil, tikaHost)
 | 
			
		||||
	// 打开 PDF 文件
 | 
			
		||||
	file, err := os.Open(filePath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with open file: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer file.Close()
 | 
			
		||||
 | 
			
		||||
	// 使用 Tika 提取 PDF 文件的文本内容
 | 
			
		||||
	content, err := client.Parse(context.TODO(), file)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse file: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ext := filepath.Ext(filePath)
 | 
			
		||||
	switch ext {
 | 
			
		||||
	case ".doc", ".docx", ".pdf", ".pptx", "ppt":
 | 
			
		||||
		return cleanBlankLine(cleanHtml(content, false)), nil
 | 
			
		||||
	case ".xls", ".xlsx":
 | 
			
		||||
		return cleanBlankLine(cleanHtml(content, true)), nil
 | 
			
		||||
	default:
 | 
			
		||||
		return cleanBlankLine(content), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 清理文本内容
 | 
			
		||||
func cleanHtml(html string, keepTable bool) string {
 | 
			
		||||
	// 清理 HTML 标签
 | 
			
		||||
	var policy *bluemonday.Policy
 | 
			
		||||
	if keepTable {
 | 
			
		||||
		policy = bluemonday.NewPolicy()
 | 
			
		||||
		policy.AllowElements("table", "thead", "tbody", "tfoot", "tr", "td", "th")
 | 
			
		||||
	} else {
 | 
			
		||||
		policy = bluemonday.StrictPolicy()
 | 
			
		||||
	}
 | 
			
		||||
	return policy.Sanitize(html)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func cleanBlankLine(content string) string {
 | 
			
		||||
	lines := strings.Split(content, "\n")
 | 
			
		||||
	texts := make([]string, 0)
 | 
			
		||||
	for _, line := range lines {
 | 
			
		||||
		line = strings.TrimSpace(line)
 | 
			
		||||
		if len(line) < 2 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// discard image
 | 
			
		||||
		if strings.HasSuffix(line, ".png") ||
 | 
			
		||||
			strings.HasSuffix(line, ".jpg") ||
 | 
			
		||||
			strings.HasSuffix(line, ".jpeg") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		texts = append(texts, line)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return strings.Join(texts, "\n")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 下载文件
 | 
			
		||||
func downloadFile(url string) (string, error) {
 | 
			
		||||
	base := filepath.Base(url)
 | 
			
		||||
	dir := os.TempDir()
 | 
			
		||||
	filename := filepath.Join(dir, base)
 | 
			
		||||
	out, err := os.Create(filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	defer out.Close()
 | 
			
		||||
 | 
			
		||||
	// 获取数据
 | 
			
		||||
	resp, err := http.Get(url)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	// 写入数据到文件
 | 
			
		||||
	_, err = io.Copy(out, resp.Body)
 | 
			
		||||
	return filename, err
 | 
			
		||||
}
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user